File size: 6,162 Bytes
65eef23
93798d6
89e27db
a9853a7
8f3eda5
9b78f9c
8f3eda5
e56055d
8f3eda5
 
8335d0c
8f3eda5
93798d6
 
89e27db
 
 
 
902d725
8f3eda5
93798d6
 
8335d0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b78f9c
 
 
 
 
9f3c7b7
b3740e7
 
89e27db
 
a9853a7
8335d0c
93798d6
 
b3740e7
93798d6
a9853a7
 
 
 
 
902d725
93798d6
902d725
9b78f9c
 
 
 
902d725
 
 
 
 
 
 
 
 
 
 
 
8335d0c
bb4bb43
 
9b78f9c
8335d0c
 
 
902d725
 
 
 
 
 
9db84b2
902d725
 
 
 
93798d6
db7fef9
 
 
 
 
 
93798d6
 
b3740e7
93798d6
 
 
 
 
db7fef9
93798d6
db7fef9
93798d6
9f3c7b7
93798d6
9f3c7b7
b3740e7
db7fef9
 
 
93798d6
 
 
 
db7fef9
93798d6
42f5f04
93798d6
b3740e7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
import pandas as pd
import time
import copy
import importlib
from torch.cuda import is_available as use_cuda

import algs
import config
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import utils


# Options for the "Editing algorithm:" dropdown; the first entry is the one the
# selectbox shows by default. The text before the first ":" is the algorithm
# abbreviation: callbacks slice it off and pass it to load_editable_model, which
# imports module algs.<abbrv lower>, takes class <ABBRV upper> from it, and reads
# config.<abbrv lower>_config. Keep that prefix in sync with those names.
EDIT_ALGS = [
    "MEND: Model editor networks using gradient decomposition",
    "SERAC: Semi-parametric editing with a retrieval-augmented counterfactual model",
    "ENN: Editable neural networks",
    "KE: KnowledgeEditor",
    "FT: Fine-tuning",
    "LU: Lookup Cache",
]

def get_alg_class(alg_abbrv):
    """Return the editing-algorithm class identified by ``alg_abbrv``.

    The module ``algs.<alg_abbrv lowercased>`` is imported on demand and the
    attribute named ``alg_abbrv`` uppercased is returned from it, so the
    abbreviation's case does not matter.

    Raises:
        ModuleNotFoundError: if no such ``algs`` submodule exists.
        AttributeError: if the module has no attribute with the uppercased name.
    """
    module_name = f"algs.{alg_abbrv.lower()}"
    class_name = alg_abbrv.upper()
    return getattr(importlib.import_module(module_name), class_name)

def load_editable_model(alg_abbrv):
    """Build the editable model for ``alg_abbrv`` and store it in session state.

    The algorithm class is resolved with ``get_alg_class`` and its config is
    ``config.<alg_abbrv lowercased>_config``. The class wraps
    ``st.session_state.model``. If the config contains an ``archive`` entry,
    the state dict stored under ``archive["model"]`` is loaded into the wrapper.

    Side effects on ``st.session_state``:
        config: set to the algorithm's config object.
        editable_model: replaced by a freshly built wrapper in eval mode.

    Args:
        alg_abbrv: Algorithm abbreviation such as ``"MEND"``. Lookup of the
            module, class and config is case-insensitive.
    """
    # Reuse the shared helper rather than duplicating its import/getattr logic.
    alg_class = get_alg_class(alg_abbrv)
    st.session_state.config = getattr(config, f"{alg_abbrv.lower()}_config")
    cfg = st.session_state.config
    with st.spinner('Loading model...'):
        st.session_state.editable_model = alg_class(
            st.session_state.model,
            cfg,
            # Callable returning a deep copy of the base model; how the
            # algorithm uses it is defined in algs.* (not visible here).
            lambda: copy.deepcopy(st.session_state.model),
        ).eval()
        if "archive" in cfg:
            # NOTE(review): cfg is the object held by the imported `config`
            # module, so overwriting `archive` with the path returned by
            # load_archive persists for every later load (and every session).
            # Presumably load_archive accepts that resolved path too — confirm.
            archive, cfg.archive = utils.load_archive(str(cfg.archive))
            # Printed after loading; reports the path load_archive returned.
            print(f"Loading archive from {cfg.archive}")
            st.session_state.editable_model.load_state_dict(archive["model"])

def generate(ids):
    """Decode one completion for ``ids`` with the current editable model.

    Generation uses 3 beams, at most 20 new tokens and a single returned
    sequence.

    Args:
        ids: Input token ids as produced by the session tokenizer. Callers in
            this file move them to ``st.session_state.device`` first.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    decoding_options = {
        "max_new_tokens": 20,
        "min_length": 1,
        "num_return_sequences": 1,
        "num_beams": 3,
    }
    model = st.session_state.editable_model
    sequences = model.generate(input_ids=ids, **decoding_options)
    texts = st.session_state.tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return texts[0]

def reset():
    """Clear the edit table and generation history, then reload the algorithm.

    This is the callback for both the algorithm selector and the
    "Clear edits" button. After emptying both session DataFrames, it rebuilds
    ``st.session_state.editable_model`` for the algorithm currently chosen in
    the dropdown via ``load_editable_model``.
    """
    # Rows are dropped in place so the same DataFrame objects (and their
    # column layout) stay in session state.
    st.session_state.edits.drop(st.session_state.edits.index, inplace=True)
    # Use model_outputs' own index: `edits.index` is already empty at this
    # point, so passing it made this drop a silent no-op.
    st.session_state.model_outputs.drop(st.session_state.model_outputs.index, inplace=True)

    selected_alg = st.session_state.alg_selector
    # Dropdown entries look like "MEND: ..."; keep the text before the colon.
    alg_abbrv = selected_alg[:selected_alg.index(":")]
    load_editable_model(alg_abbrv)

def apply_edit():
    """Apply the edit entered in the "Edit input" / "Edit target" boxes.

    This is the callback for the "Apply edit" button. It proceeds as follows:

    1. Tokenize the input and label onto ``st.session_state.device``.
    2. Call ``editable_model.edit`` and replace ``st.session_state.editable_model``
       with the model it returns.
    3. Only then append the pair to ``st.session_state.edits``.

    If either box is empty (or whitespace only), a warning is shown and
    nothing is changed.
    """
    # NOTE(review): edit_input / edit_label are module-level values bound by the
    # script run that registered this callback. Streamlit recommends reading
    # widget values from keyed st.session_state entries inside callbacks —
    # confirm these can never be stale.
    input_str, label_str = str(edit_input), str(edit_label)
    if not input_str.strip() or not label_str.strip():
        st.warning("Please provide both an edit input and an edit target.")
        return

    def to_ids(text):
        # Tokenize a single string and move its ids to the session's device.
        encoded = st.session_state.tokenizer(text, return_tensors="pt")
        return encoded["input_ids"].to(st.session_state.device)

    with st.spinner("Editing model..."):
        edit_sample = {"input_ids": to_ids(input_str), "labels": to_ids(label_str)}
        st.session_state.editable_model, _ = st.session_state.editable_model.edit(edit_sample, detach_history=True)

    # Record the edit only after it has been applied, so a failed edit does
    # not leave a phantom row in the "Edits applied so far" table.
    st.session_state.edits.loc[len(st.session_state.edits)] = [input_str, label_str]

def sample_model():
    """Generate a completion for the test prompt and log it to the history.

    This is the callback for the "Generate" button. It appends one row to
    ``st.session_state.model_outputs`` with these fields:

    - the prompt;
    - the decoded output;
    - the number of edits applied so far;
    - the abbreviation (text before ":") of the selected algorithm.
    """
    prompt = str(test_input)
    tokenizer = st.session_state.tokenizer
    with st.spinner('Generating completion...'):
        prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(st.session_state.device)
        completion = generate(prompt_ids)

    history = st.session_state.model_outputs
    selected = st.session_state.alg_selector
    history.loc[len(history)] = [
        prompt,
        completion,
        len(st.session_state.edits),
        selected[:selected.index(":")],
    ]

################################
#### Backend initialization ####
################################
if "init" not in st.session_state:
    # One-time setup per browser session. It creates the empty edit and
    # generation-history tables, loads the base T5 model and tokenizer, and
    # builds the editable wrapper for the default algorithm.
    st.session_state.edits = pd.DataFrame([], columns=["Edit input", "Edit label"])
    st.session_state.model_outputs = pd.DataFrame([], columns=["Input", "Output", "N edits", "Alg"])
    # GPU selection is currently turned off; use
    # `"cuda" if use_cuda() else "cpu"` to pick a GPU when one is available.
    st.session_state.device = "cpu"
    base_model_name = "google/t5-large-ssm-nq"
    with st.spinner('Loading model...'):
        st.session_state.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        st.session_state.model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name).to(st.session_state.device).eval()
    # load_editable_model shows its own "Loading model..." spinner.
    # Start with the dropdown's default (first) entry so model and selector agree.
    default_alg = EDIT_ALGS[0]
    load_editable_model(default_alg[:default_alg.index(":")])
    # Mark the session initialized only after every load has succeeded. If a
    # load raises, the next rerun retries setup instead of running with a
    # half-initialized session.
    st.session_state.init = True

########################
#### Interface code ####
########################

# Page header and introductory copy.
st.title("Language Model Editing")
for intro_md in (
    "**Note: this HF space is currently under development and doesn't actually work yet!**",
    "The goal of this demo is to give you a sense of the *abilities* and *limitations* of existing methods for **editing** pre-trained language models. **Model editing** algorithms use a single input-output pair to update a pre-trained model's behavior for that input (and ideally, related inputs).",
    "This demo uses a [T5-large](https://huggingface.co/google/t5-large-ssm-nq) model fine-tuned on [Natural Questions](https://arxiv.org/pdf/2002.08910.pdf) as the base pre-trained model.",
):
    st.markdown(intro_md)
st.write("You can choose from a variety of algorithms for model editing in the dropdown below. At the bottom of the page, you can query the model for whatever input you want before/after editing.")
st.markdown("***")

# Algorithm picker plus "Clear edits"; both invoke reset().
picker_col, clear_col = st.columns([5, 1])
alg_selector = picker_col.selectbox("Editing algorithm:", EDIT_ALGS, key="alg_selector", on_change=reset)
clear_col.text("ㅤ")  # visually blank text, presumably to align the button with the labelled widget
clear_col.button("Clear edits", on_click=reset)

st.write("Edits applied so far:")
st.table(st.session_state.edits)

# Edit entry row. apply_edit reads edit_input / edit_label as module globals,
# so these names must stay as they are.
input_col, target_col, apply_col = st.columns([3, 2, 1])
edit_input = input_col.text_input("Edit input:", placeholder="e.g., 'What is the tallest mountain on Earth?'")
edit_label = target_col.text_input("Edit target:", placeholder="e.g., 'Denali'")
apply_col.text("ㅤ")
edit_button = apply_col.button("Apply edit", on_click=apply_edit)

st.markdown("***")

# Sampling row. The label reflects whether any edit has been applied, and
# sample_model reads test_input as a module global.
sample_title = (
    "Input to sample from *unedited* model:"
    if st.session_state.edits.empty
    else "Input to sample from *edited* model:"
)
prompt_col, generate_col = st.columns([5, 1])
test_input = prompt_col.text_input(sample_title, placeholder="e.g., 'What is the earth's tallest mountain?'")
generate_col.text("ㅤ")
generate_button = generate_col.button("Generate", on_click=sample_model)

st.write("Model generation history:")
st.table(st.session_state.model_outputs)