File size: 2,378 Bytes
9f3c7b7
65eef23
93798d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65eef23
9f3c7b7
93798d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f3c7b7
93798d6
9f3c7b7
93798d6
 
 
 
 
 
 
 
 
 
42f5f04
93798d6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from turtle import onclick
import streamlit as st
import pandas as pd

# Choices offered by the "Editing algorithm:" selector, in display order:
# published editing methods (labelled "<name>: <title> [<citation>]")
# followed by two simple baselines.
EDIT_ALGS = [
    "MEND: Model editor networks using gradient decomposition [Mitchell et al., 2021]",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model [Mitchell et al., 2022]",
    "ENN: Editable Neural Networks [Sinitsin et al., 2020]",
    "KE: KnowledgeEditor [De Cao et al., 2020]",
] + [
    "Fine-tuning",
    "Lookup Cache",
]

# Page heading for the demo app.
st.title("Model editing demo")

# Persistence pattern adapted from:
# https://discuss.streamlit.io/t/simple-example-of-persistence-and-waiting-for-input/2111
@st.cache(allow_output_mutation=True)
def Edits():
    """Return the cached, mutable table of applied edits.

    With ``allow_output_mutation=True`` the cached DataFrame object itself is
    handed back on every rerun, so rows appended to it (and rows dropped from
    it in place) persist across reruns.

    NOTE(review): ``st.cache`` results live at process level, so this table is
    likely shared by every browser session; ``st.session_state`` would make it
    per-session — confirm intended behaviour.
    """
    edit_columns = ["Edit input", "Edit label"]
    return pd.DataFrame([], columns=edit_columns)

@st.cache(allow_output_mutation=True)
def ModelOutput():
    """Return the cached, mutable table of model generations.

    Each row records the prompt, the generated text, and how many edits had
    been applied when it was sampled. As with the edits table, the same
    DataFrame object is returned on every rerun so in-place changes persist.
    """
    output_columns = ["Input", "Generation", "Edits applied"]
    return pd.DataFrame([], columns=output_columns)

# Fetch the two persistent tables (same objects on every rerun); the
# callbacks below mutate them in place.
edits, model_outputs = Edits(), ModelOutput()

def reset():
    """Clear the edit and generation histories, then echo the chosen algorithm.

    Used as the ``on_change`` callback of the algorithm selectbox and the
    ``on_click`` callback of the "Reset model" button.
    """
    # Empty both cached tables in place: they are the objects returned by the
    # st.cache'd factories, so they must be mutated rather than rebound.
    edits.drop(edits.index, inplace=True)
    # Use model_outputs' own index -- edits.index is already empty after the
    # line above, so dropping it would leave the generation history intact.
    model_outputs.drop(model_outputs.index, inplace=True)
    # Widget callbacks run before the script re-executes, so the module-level
    # ``alg_selector`` may still hold the previous selection. Prefer the value
    # stored in session state when the selectbox is keyed "edit_alg", and fall
    # back to the module-level value otherwise.
    current_edit_alg = str(st.session_state.get("edit_alg", alg_selector))
    st.write(current_edit_alg)
    # TODO: reset the model itself here (and maybe show a progress spinner).
    ############# Need to reset the model here (and maybe show progress spinner?)

def apply_edit():
    """Append the current (edit input, edit target) pair to the edit history.

    ``on_click`` callback of the "Apply edit" button. Reads the module-level
    ``edit_input`` / ``edit_label`` text-input values; if either is empty or
    whitespace-only the click is ignored, so blank rows are never recorded.
    """
    # NOTE(review): callbacks run before the script re-executes, so these
    # globals hold the values from the preceding run; keyed widgets read via
    # st.session_state would be more robust -- confirm.
    prompt, target = str(edit_input), str(edit_label)
    if not prompt.strip() or not target.strip():
        return
    edits.loc[len(edits)] = [prompt, target]

def sample_model():
    """Record a generation for the current test input.

    ``on_click`` callback of the "Generate" button: appends a row holding the
    module-level ``test_input`` text, a generation string (currently a fixed
    placeholder), and the number of edits applied so far.
    """
    new_row = [str(test_input), "blah blah blah", len(edits)]
    model_outputs.loc[len(model_outputs)] = new_row

# --- Algorithm selection and edit history ----------------------------------
# key="edit_alg" stores the selection in st.session_state, where callbacks can
# read the newly chosen value (module globals still hold the previous run's).
alg_selector = st.selectbox("Editing algorithm:", EDIT_ALGS, key="edit_alg", on_change=reset)

st.write("Edits applied so far:")
st.table(edits)
st.button("Reset model", on_click=reset)

st.markdown("***")

# --- Edit entry: input text, desired target, apply button ------------------
col1, col2, col3 = st.columns([3, 2, 1])
with col1:
    edit_input = st.text_input("Edit input:", placeholder="e.g., 'What is the tallest mountain on Earth?'")
with col2:
    edit_label = st.text_input("Edit target:", placeholder="e.g., 'Denali'", help="The desired output of the model for the edit input")
with col3:
    # Empty heading, presumably a vertical spacer so the button lines up with
    # the text inputs.
    st.markdown("##")
    st.button("Apply edit", on_click=apply_edit)

st.markdown("***")

# --- Sampling: prompt plus generate button ---------------------------------
col1, col2 = st.columns([5, 1])
with col1:
    # The label tells the user whether any edits are currently in effect.
    if edits.empty:
        title = "Input to sample from *unedited* model:"
    else:
        title = "Input to sample from *edited* model:"
    test_input = st.text_input(title, placeholder="e.g., 'What is the earth's tallest mountain?'")
with col2:
    st.markdown("##")
    st.button("Generate", on_click=sample_model)

st.write("Model generation history:")
st.table(model_outputs)