# NOTE(review): `onclick` is not referenced anywhere in this script and looks
# like an accidental editor auto-import; `turtle` pulls in tkinter on import.
# Confirm nothing relies on it before removing.
from turtle import onclick

import pandas as pd
import streamlit as st

# Human-readable names of the editing algorithms offered in the selector.
EDIT_ALGS = [
    "MEND: Model editor networks using gradient decomposition [Mitchell et al., 2021]",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model [Mitchell et al., 2022]",
    "ENN: Editable Neural Networks [Sinitsin et al., 2020]",
    "KE: KnowledgeEditor [De Cao et al., 2020]",
    "Fine-tuning",
    "Lookup Cache",
]

st.title("Model editing demo")


# https://discuss.streamlit.io/t/simple-example-of-persistence-and-waiting-for-input/2111
@st.cache(allow_output_mutation=True)
def Edits():
    """Return the DataFrame holding the edits applied so far.

    The result is cached with ``allow_output_mutation=True`` so the same
    DataFrame object is handed back on every script rerun and can be
    mutated in place to persist state.

    Returns:
        An initially empty DataFrame with columns "Edit input" and
        "Edit label".
    """
    # NOTE(review): st.cache is presumably process-wide, so this table would
    # be shared by every browser session -- confirm that is acceptable.
    return pd.DataFrame([], columns=["Edit input", "Edit label"])


@st.cache(allow_output_mutation=True)
def ModelOutput():
    """Return the DataFrame holding the model generation history.

    Cached the same way as ``Edits`` so rows appended during one rerun are
    still present on the next.

    Returns:
        An initially empty DataFrame with columns "Input", "Generation" and
        "Edits applied".
    """
    return pd.DataFrame([], columns=["Input", "Generation", "Edits applied"])


edits = Edits()
model_outputs = ModelOutput()


def reset():
    """Clear the recorded edits and the generation history.

    Registered as the ``on_change`` callback of the algorithm selector and
    as the ``on_click`` callback of the "Reset model" button.
    """
    # Rows are removed in place because both frames are the cached objects;
    # rebinding the names to new frames would not persist across reruns.
    edits.drop(edits.index, inplace=True)
    # Each frame has to be emptied through its *own* index: `edits.index` is
    # already empty here, so using it would leave the history untouched.
    model_outputs.drop(model_outputs.index, inplace=True)

    # NOTE(review): `alg_selector` is read from module globals; when this runs
    # as the selectbox callback, confirm it reflects the newly chosen
    # algorithm rather than the value from the previous script run.
    current_edit_alg = str(alg_selector)
    st.write(current_edit_alg)
    # TODO: Need to reset the model here (and maybe show progress spinner?)
def apply_edit():
    """Append the current edit input/target pair as a new row of ``edits``.

    Registered as the ``on_click`` callback of the "Apply edit" button.
    """
    # Callbacks run before the script body is re-executed, so the module-level
    # `edit_input` / `edit_label` names may still hold values from an earlier
    # run. Session state always carries the widgets' latest values.
    edits.loc[len(edits)] = [
        str(st.session_state["edit_input"]),
        str(st.session_state["edit_label"]),
    ]


def sample_model():
    """Append a generation record for the current test input.

    Registered as the ``on_click`` callback of the "Generate" button. The
    generation text is currently a fixed placeholder string; the third
    column records how many edits exist at the time of the call.
    """
    # Read the input from session state for the same reason as in apply_edit.
    model_outputs.loc[len(model_outputs)] = [
        str(st.session_state["test_input"]),
        "blah blah blah",
        len(edits),
    ]


# --- Algorithm selection and current edit list -------------------------------
alg_selector = st.selectbox("Editing algorithm:", EDIT_ALGS, on_change=reset)

st.write("Edits applied so far:")
st.table(edits)
st.button("Reset model", on_click=reset)

st.markdown("***")

# --- Edit entry row: input, target, apply button ------------------------------
col1, col2, col3 = st.columns([3, 2, 1])
with col1:
    edit_input = st.text_input(
        "Edit input:",
        placeholder="e.g., 'What is the tallest mountain on Earth?'",
        key="edit_input",
    )
with col2:
    edit_label = st.text_input(
        "Edit target:",
        placeholder="e.g., 'Denali'",
        help="The desired output of the model for the edit input",
        key="edit_label",
    )
with col3:
    # Presumably an empty heading used as a vertical spacer so the button
    # lines up with the text boxes -- verify visually.
    st.markdown("##")
    edit_button = st.button("Apply edit", on_click=apply_edit)

st.markdown("***")

# --- Sampling row: test input and generate button -----------------------------
col1, col2 = st.columns([5, 1])
with col1:
    # The label tells the user whether any edits are currently recorded.
    title = (
        "Input to sample from *unedited* model:"
        if edits.empty
        else "Input to sample from *edited* model:"
    )
    test_input = st.text_input(
        title,
        placeholder="e.g., 'What is the earth's tallest mountain?'",
        key="test_input",
    )
with col2:
    st.markdown("##")
    generate_button = st.button("Generate", on_click=sample_model)

st.write("Model generation history:")
st.table(model_outputs)