File size: 2,432 Bytes
72cfe15
f0b559a
 
 
 
 
 
 
 
 
 
72cfe15
 
 
f0b559a
 
 
72cfe15
f0b559a
72cfe15
 
 
f0b559a
72cfe15
 
 
f0b559a
 
 
275482b
f0b559a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import streamlit as st

import numpy as np
import argparse
import itertools
import time
import ast
import re
from tracr.compiler import compiling
from typing import get_args
import inspect
import pickle
import base64

from abstract_syntax_tree import *
from python_embedded_rasp import *
from rasp_synthesizer import *

# HELPER FUNCTIONS
def download_model(model):
    """Render a browser download link for a pickled copy of ``model``.

    The object is serialized with ``pickle``, base64-encoded, and inlined
    into a ``data:`` URI inside an HTML anchor. The anchor is written to the
    page through ``st.markdown`` with raw HTML enabled; its ``download``
    attribute suggests the filename ``model_params.pkl``.

    Args:
        model: Any picklable object.

    Returns:
        None. The link is emitted to the Streamlit page as a side effect.
    """
    # Pickle -> base64 bytes -> text, so the payload can live inside the URI.
    encoded_payload = base64.b64encode(pickle.dumps(model)).decode()
    link_html = f'<a href="data:file/output_model;base64,{encoded_payload}" download="model_params.pkl">Download Haiku Model Parameters in a .pkl File</a>'
    # Raw HTML must be allowed, otherwise the anchor tag would be escaped.
    st.markdown(link_html, unsafe_allow_html=True)


# APP DRIVER CODE
st.title("Bottom Up Synthesis for RASP")

# Slider arguments after the label are: minimum, maximum, initial value.
max_weight = st.slider(
    "Choose the maximum program weight to search for (~ size of transformer)",
    2,
    20,
    15,
)

# The examples are typed in as text. The default value is a list holding one
# [input, output] pair written as a Python-style literal.
default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
example_text = st.text_input(
    label="Provide Input and Output Examples",
    value=default_example,
)

# analyze_examples comes from a project module; it yields the inputs and the
# outputs as two parallel collections, which are re-paired here.
inputs, outs = analyze_examples(example_text)
examples = list(zip(inputs, outs))
st.write("Received the following input and output examples:")
st.write(examples)

# Length of the longest input sequence (0 when there are no inputs).
max_seq_len = max((len(sequence) for sequence in inputs), default=0)
vocab = get_vocabulary(examples)

# Echo the settings that the synthesizer and compiler will be run with.
st.subheader("Synthesis Configuration")
st.write("Running synthesizer with")
for template, setting in (
    ("Vocab: {}", vocab),
    ("Max sequence length: {}", max_seq_len),
    ("Max weight: {}", max_weight),
):
    st.write(template.format(setting))

# Captures the layer index embedded in a parameter name, e.g. "layer_3".
_LAYER_INDEX_RE = re.compile(r'layer_(\d+)')


def extract_layer_number(s):
    """Return the 1-based layer number named in ``s``, or None if there is none.

    Args:
        s: A parameter name that may contain a ``layer_<index>`` component.

    Returns:
        ``index + 1`` for the first ``layer_<index>`` occurrence (the index in
        the name is treated as 0-based), or ``None`` when ``s`` contains no
        such component.
    """
    match = _LAYER_INDEX_RE.search(s)
    return int(match.group(1)) + 1 if match else None


# Streamlit emits elements in the order the script calls them, so the heading
# and the "may take a while" notice must be written BEFORE the synthesizer is
# started; otherwise the notice only shows up once the wait is already over.
st.subheader("Synthesis Results:")
st.caption("May take a while.")

program, approx_programs = run_synthesizer(examples, max_weight)

if program:
    # Turn the synthesized program into the form the tracr compiler accepts.
    algorithm = program.to_python()

    bos = "BOS"
    model = compiling.compile_rasp_to_model(
        algorithm,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=bos,
    )

    # Derive the layer count from ALL parameter names rather than only the
    # last one: the final key is not guaranteed to belong to the
    # highest-numbered layer (or to any layer at all), which previously could
    # report a wrong count or "None layer(s)".
    layer_numbers = [
        number
        for number in map(extract_layer_number, model.params)
        if number is not None
    ]
    layer_num = max(layer_numbers, default=0)

    st.write(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
    st.write(program.str())

    st.write("Here is a model download link: ")
    # Only the parameter mapping is offered for download, not the whole model.
    download_model(model.params)
else:
    # No exact match: surface the approximate candidates returned instead.
    st.write(f"No exact program found but here is some brainstorming fodder for you: {approx_programs}")