model-editing / app.py
Eric Mitchell
Lots of changes.
93798d6
raw history blame
No virus
2.38 kB
from turtle import onclick
import streamlit as st
import pandas as pd
# Choices offered by the "Editing algorithm:" selector, in display order.
# Entries in the first group cite the paper that introduced them; entries in
# the second group carry no citation.
_CITED_EDIT_ALGS = (
    "MEND: Model editor networks using gradient decomposition [Mitchell et al., 2021]",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model [Mitchell et al., 2022]",
    "ENN: Editable Neural Networks [Sinitsin et al., 2020]",
    "KE: KnowledgeEditor [De Cao et al., 2020]",
)
_UNCITED_EDIT_ALGS = ("Fine-tuning", "Lookup Cache")
EDIT_ALGS = [*_CITED_EDIT_ALGS, *_UNCITED_EDIT_ALGS]
st.title("Model editing demo")

# Persistent tables for the demo.
# These deliberately live in st.session_state rather than behind
# @st.cache(allow_output_mutation=True) (the trick from
# https://discuss.streamlit.io/t/simple-example-of-persistence-and-waiting-for-input/2111):
# st.cache hands the *same* object to every browser session in the server
# process, so all visitors would share -- and concurrently mutate -- one edit
# log. Session state survives reruns but is private to each session.


def _session_frame(key, columns):
    """Return the DataFrame stored under ``key`` in this session's state.

    The first call in a session creates an empty DataFrame with ``columns`` and
    stores it. Later calls, including those on every script rerun, return that
    same object, so in-place mutations made by the widget callbacks persist.
    """
    if key not in st.session_state:
        st.session_state[key] = pd.DataFrame([], columns=columns)
    return st.session_state[key]


def Edits():
    """Return this session's table of applied edits.

    Columns: "Edit input", "Edit label".
    """
    return _session_frame("edits_df", ["Edit input", "Edit label"])


def ModelOutput():
    """Return this session's generation history.

    Columns: "Input", "Generation", "Edits applied".
    """
    return _session_frame("model_outputs_df", ["Input", "Generation", "Edits applied"])
# Module-level handles to the persistent tables. The callbacks defined below
# mutate these DataFrames in place rather than rebinding the names.
edits, model_outputs = Edits(), ModelOutput()
def reset():
    """Clear every recorded edit and generation, then echo the selected algorithm.

    Registered as the ``on_change`` callback of the algorithm selectbox and the
    ``on_click`` callback of the "Reset model" button. Both tables are emptied
    in place, so the module-level ``edits`` / ``model_outputs`` handles stay valid.
    """
    edits.drop(edits.index, inplace=True)
    # Drop by model_outputs' own index: edits.index is already empty at this
    # point, so passing it here would leave the generation history untouched.
    model_outputs.drop(model_outputs.index, inplace=True)
    # NOTE(review): callbacks run before the script reruns, so when this fires
    # from the selectbox's on_change, ``alg_selector`` presumably still holds the
    # previously selected algorithm. Giving the selectbox a ``key`` and reading
    # st.session_state[key] here would report the new choice -- confirm.
    current_edit_alg = str(alg_selector)
    st.write(current_edit_alg)
    # TODO: reset the underlying model here (and maybe show a progress spinner).
def apply_edit():
    """Append the current (edit input, edit target) pair to the edit log.

    ``on_click`` callback of the "Apply edit" button. Does nothing when either
    field is blank. Otherwise a stray click would log an empty edit and flip the
    sampling prompt to the "*edited* model" wording.
    """
    # NOTE(review): callbacks run before the rerun, so these globals are the
    # values bound by the previous script run. Reading them from
    # st.session_state via widget keys would be more robust -- confirm.
    new_input = str(edit_input)
    new_label = str(edit_label)
    if not new_input.strip() or not new_label.strip():
        return
    edits.loc[len(edits)] = [new_input, new_label]
def sample_model():
    """Log a generation for the current test input.

    ``on_click`` callback of the "Generate" button. Each row stores the prompt,
    the generated text and ``len(edits)``, the number of edits applied at the
    time. A blank prompt is ignored.
    """
    prompt = str(test_input)
    if not prompt.strip():
        return
    # TODO: replace the placeholder with a real sample from the (edited) model.
    model_outputs.loc[len(model_outputs)] = [prompt, "blah blah blah", len(edits)]
# --- Algorithm choice and current edit log ----------------------------------
# Changing the selected algorithm triggers reset().
alg_selector = st.selectbox(
    label="Editing algorithm:",
    options=EDIT_ALGS,
    on_change=reset,
)

st.write("Edits applied so far:")
st.table(edits)

# Manual reset, using the same callback as the selector above.
st.button("Reset model", on_click=reset)
st.markdown("***")

# --- Edit-entry row: columns with relative widths 3 : 2 : 1 -------------------
col1, col2, col3 = st.columns([3, 2, 1])

edit_input = col1.text_input(
    "Edit input:",
    placeholder="e.g., 'What is the tallest mountain on Earth?'",
)
edit_label = col2.text_input(
    "Edit target:",
    placeholder="e.g., 'Denali'",
    help="The desired output of the model for the edit input",
)

# Presumably an empty heading used as a vertical spacer so the button sits
# level with the text boxes -- confirm.
col3.markdown("##")
edit_button = col3.button("Apply edit", on_click=apply_edit)
st.markdown("***")
col1, col2 = st.columns([5, 1])
with col1:
    # The prompt label tells the user whether any edits are currently applied.
    # NOTE(review): Streamlit derives an un-keyed widget's identity from its
    # label, so switching between these two labels (first edit applied, or a
    # reset) likely clears whatever was typed into this box -- confirm.
    if edits.empty:
        title = "Input to sample from *unedited* model:"
    else:
        title = "Input to sample from *edited* model:"
    test_input = st.text_input(title, placeholder="e.g., 'What is the earth's tallest mountain?'")
with col2:
    # Presumably an empty heading used as a vertical spacer so the button sits
    # level with the text box -- confirm.
    st.markdown("##")
    generate_button = st.button("Generate", on_click=sample_model)
st.write("Model generation history:")
st.table(model_outputs)