Spaces:
Sleeping
Sleeping
import streamlit as st | |
import numpy as np | |
import argparse | |
import itertools | |
import time | |
import ast | |
import re | |
from tracr.compiler import compiling | |
from typing import get_args | |
import inspect | |
import pickle | |
import base64 | |
from abstract_syntax_tree import * | |
from python_embedded_rasp import * | |
from rasp_synthesizer import * | |
# HELPER FUNCTIONS | |
def download_model(model):
    """Render a link that downloads ``model`` as a pickle file.

    The object is pickled, base64-encoded, and embedded in a ``data:`` URI
    inside an HTML anchor whose ``download`` attribute names the file
    ``model_params.pkl``. The anchor is rendered via ``st.markdown``.

    Args:
        model: The object to serialize; it must be picklable.
    """
    encoded = base64.b64encode(pickle.dumps(model)).decode()
    link = (
        f'<a href="data:file/output_model;base64,{encoded}" '
        'download="model_params.pkl">'
        "Download Haiku Model Parameters in a .pkl File</a>"
    )
    # Raw HTML must be explicitly allowed, otherwise the anchor is not rendered.
    st.markdown(link, unsafe_allow_html=True)
# APP DRIVER CODE
st.title("Bottom Up Synthesis for RASP")

# --- User inputs -----------------------------------------------------------
# Positional slider arguments are (label, min, max, initial value).
max_weight = st.slider(
    "Choose the maximum program weight to search for (~ size of transformer)",
    2,
    20,
    15,
)
default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
example_text = st.text_input(
    label="Provide Input and Output Examples",
    value=default_example,
)

# NOTE(review): analyze_examples is defined elsewhere; whatever it raises on
# malformed text currently propagates to the Streamlit error page.
inputs, outs = analyze_examples(example_text)
examples = list(zip(inputs, outs))
st.write("Received the following input and output examples:")
st.write(examples)

# --- Derived configuration -------------------------------------------------
# Length of the longest input sequence; 0 when there are no examples.
max_seq_len = max((len(seq) for seq in inputs), default=0)
vocab = get_vocabulary(examples)

st.subheader("Synthesis Configuration")
st.write("Running synthesizer with")
st.write(f"Vocab: {vocab}")
st.write(f"Max sequence length: {max_seq_len}")
st.write(f"Max weight: {max_weight}")

# --- Search ----------------------------------------------------------------
# Emit the results heading and the "may take a while" notice BEFORE starting
# the search: elements are sent to the page as the script executes, so issuing
# them after the blocking call would only show the notice once the wait is
# already over.
st.subheader("Synthesis Results:")
st.caption("May take a while.")
program, approx_programs = run_synthesizer(examples, max_weight)
def extract_layer_number(s):
    """Return the 1-based layer number encoded in ``s``, or ``None``.

    Searches ``s`` for the first ``layer_<digits>`` occurrence and returns
    that integer plus one (so ``"layer_0"`` yields 1). Returns ``None`` when
    the pattern does not occur.
    """
    match = re.search(r'layer_(\d+)', s)
    if match:
        return int(match.group(1)) + 1
    return None


if program:
    # Compile the synthesized program into a transformer.
    algorithm = program.to_python()
    bos = "BOS"
    model = compiling.compile_rasp_to_model(
        algorithm,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=bos,
    )
    hk_model = model.params

    # Take the highest layer index over ALL parameter names rather than
    # trusting the last key: that avoids depending on dict ordering and on the
    # final key containing "layer_N". With no layer-named parameters the
    # count is reported as 0 instead of "None".
    layer_num = max(
        (
            number
            for number in map(extract_layer_number, hk_model.keys())
            if number is not None
        ),
        default=0,
    )

    st.write(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
    st.write(program.str())
    st.write("Here is a model download link: ")
    download_model(hk_model)
else:
    # No exact match: surface the approximate candidates returned by the search.
    st.write(f"No exact program found but here is some brainstorming fodder for you: {approx_programs}")