johnhew lora-x commited on
Commit
ffb38f8
0 Parent(s):

Duplicate from lora-x/Backpack

Browse files

Co-authored-by: Lora Xie <lora-x@users.noreply.huggingface.co>

Files changed (7) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +550 -0
  4. requirements.txt +4 -0
  5. senses/all_vecs_mtx.pt +3 -0
  6. senses/lm_head.pt +3 -0
  7. senses/use_senses.py +44 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Backpack
3
+ emoji: 🏃
4
+ colorFrom: green
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.24.1
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: lora-x/Backpack
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from transformers import AutoModelForCausalLM
4
+ import pandas as pd
5
+ import gradio as gr
6
+
7
+ # Build model & get some layers
8
+ tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
9
+ m = AutoModelForCausalLM.from_pretrained("lora-x/backpack-gpt2", trust_remote_code=True)
10
+ m.eval()
11
+
12
+ lm_head = m.get_lm_head() # (V, d)
13
+ word_embeddings = m.backpack.get_word_embeddings() # (V, d)
14
+ sense_network = m.backpack.get_sense_network() # (V, nv, d)
15
+ num_senses = m.backpack.get_num_senses()
16
+ sense_names = [i for i in range(num_senses)]
17
+
18
+ """
19
+ Single token sense lookup
20
+ """
21
+ def visualize_word(word, count=10, remove_space=False):
22
+
23
+ if not remove_space:
24
+ word = ' ' + word
25
+ print(f"Looking up word '{word}'...")
26
+
27
+ token_ids = tokenizer(word)['input_ids']
28
+ tokens = [tokenizer.decode(token_id) for token_id in token_ids]
29
+ tokens = ", ".join(tokens) # display tokenization for user
30
+ print(f"Tokenized as: {tokens}")
31
+ # look up sense vectors only for the first token
32
+ # contents = vecs[token_ids[0]] # torch.Size([16, 768])
33
+ sense_input_embeds = word_embeddings(torch.tensor([token_ids[0]]).long().unsqueeze(0)) # (bs=1, s=1, d), sense_network expects bs dim
34
+ senses = sense_network(sense_input_embeds) # -> (bs=1, nv, s=1, d)
35
+ senses = torch.squeeze(senses) # (nv, s=1, d)
36
+
37
+ # for pos and neg respectively, create a list (for each sense) of list (top k) of tuples (word, logit)
38
+ pos_word_lists = []
39
+ neg_word_lists = []
40
+ sense_names = [] # column header
41
+ for i in range(senses.shape[0]):
42
+ logits = lm_head(senses[i,:])
43
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
44
+ sense_names.append('sense {}'.format(i))
45
+
46
+ pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(count)]
47
+ pos_sorted_logits = [sorted_logits[j].item() for j in range(count)]
48
+ pos_word_lists.append(list(zip(pos_sorted_words, pos_sorted_logits)))
49
+
50
+ neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(count)]
51
+ neg_sorted_logits = [sorted_logits[-j-1].item() for j in range(count)]
52
+ neg_word_lists.append(list(zip(neg_sorted_words, neg_sorted_logits)))
53
+
54
+ def create_dataframe(word_lists, sense_names, count):
55
+ data = dict(zip(sense_names, word_lists))
56
+ df = pd.DataFrame(index=[i for i in range(count)],
57
+ columns=list(data.keys()))
58
+ for prop, word_list in data.items():
59
+ for i, word_pair in enumerate(word_list):
60
+ cell_value = "space ({:.2f})".format(word_pair[1])
61
+ cell_value = "{} ({:.2f})".format(word_pair[0], word_pair[1])
62
+ df.at[i, prop] = cell_value
63
+ return df
64
+
65
+ pos_df = create_dataframe(pos_word_lists, sense_names, count)
66
+ neg_df = create_dataframe(neg_word_lists, sense_names, count)
67
+
68
+ return pos_df, neg_df, tokens
69
+
70
+ """
71
+ Returns:
72
+ - tokens: the tokenization of the input sentence, also used as options to choose from for get_token_contextual_weights
73
+ - top_k_words_df: a dataframe of the top k words predicted by the model
74
+ - length: of the input sentence, stored as a gr.State variable so other methods can find the
75
+ contextualization weights for the *last* token that's needed
76
+ - contextualization_weights: gr.State variable, stores the contextualization weights for the input sentence
77
+ """
78
+ def predict_next_word (sentence, top_k = 5, contextualization_weights = None):
79
+
80
+ if sentence == "":
81
+ return None, None, None, None
82
+
83
+ # For better tokenization, by default, adds a space at the beginning of the sentence if it doesn't already have one
84
+ # and remove trailing space
85
+ sentence = sentence.strip()
86
+ if sentence[0] != ' ':
87
+ sentence = ' ' + sentence
88
+ print(f"Sentence: '{sentence}'")
89
+
90
+ # Make input, keeping track of original length
91
+ token_ids = tokenizer(sentence)['input_ids']
92
+ tokens = [[tokenizer.decode(token_id) for token_id in token_ids]] # a list of a single list because used as dataframe
93
+ length = len(token_ids)
94
+ inp = torch.zeros((1,512)).long()
95
+ inp[0,:length] = torch.tensor(token_ids).long()
96
+
97
+ # Get output at correct index
98
+ if contextualization_weights is None:
99
+ print("contextualization_weights IS None, freshly computing contextualization_weights")
100
+ output = m(inp)
101
+ logits, contextualization_weights = output.logits[0,length-1,:], output.contextualization
102
+ # Store contextualization weights and return it as a gr.State var for use by get_token_contextual_weights
103
+ else:
104
+ print("contextualization_weights is NOT None, using passed in contextualization_weights")
105
+ output = m.run_with_custom_contextualization(inp, contextualization_weights)
106
+ logits = output.logits[0,length-1,:]
107
+ probs = logits.softmax(dim=-1) # probs over next word
108
+ probs, indices = torch.sort(probs, descending=True)
109
+ top_k_words = [(tokenizer.decode(indices[i]), round(probs[i].item(), 3)) for i in range(top_k)]
110
+ top_k_words_df = pd.DataFrame(top_k_words, columns=['word', 'probability'], index=range(1, top_k+1))
111
+
112
+ top_k_words_df = top_k_words_df.T
113
+
114
+ print(top_k_words_df)
115
+
116
+ return tokens, top_k_words_df, length, contextualization_weights
117
+
118
+
119
+ """
120
+ Returns a dataframe of senses with weights for the selected token.
121
+
122
+ Args:
123
+ contextualization_weights: a gr.State variable that stores the contextualization weights for the input sentence.
124
+ length: length of the input sentence, used to get the contextualization weights for the last token
125
+ token: the selected token
126
+ token_index: the index of the selected token in the input sentence
127
+ pos_count: how many top positive words to display for each sense
128
+ neg_count: how many top negative words to display for each sense
129
+ """
130
+ def get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count = 5, neg_count = 3):
131
+ print(">>>>>in get_token_contextual_weights")
132
+ print(f"Selected {token_index}th token: {token}")
133
+
134
+ # get contextualization weights for the selected token
135
+ # Only care about the weights for the last word, since that's what contributes to the output
136
+ token_contextualization_weights = contextualization_weights[0, :, length-1, token_index]
137
+ token_contextualization_weights_list = [round(x, 3) for x in token_contextualization_weights.tolist()]
138
+
139
+ # get sense vectors of the selected token
140
+ token_ids = tokenizer(token)['input_ids'] # keep as a list bc sense_network expects s dim
141
+ sense_input_embeds = word_embeddings(torch.tensor(token_ids).long().unsqueeze(0)) # (bs=1, s=1, d), sense_network expects bs dim
142
+ senses = sense_network(sense_input_embeds) # -> (bs=1, nv, s=1, d)
143
+ senses = torch.squeeze(senses) # (nv, s=1, d)
144
+
145
+ # build dataframe
146
+ pos_dfs, neg_dfs = [], []
147
+
148
+ for i in range(num_senses):
149
+ logits = lm_head(senses[i,:]) # (vocab,) [768, 50257] -> [50257]
150
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
151
+
152
+ pos_sorted_words = [tokenizer.decode(sorted_indices[j]) for j in range(pos_count)]
153
+ pos_df = pd.DataFrame(pos_sorted_words, columns=["Sense {}".format(i)])
154
+ pos_dfs.append(pos_df)
155
+
156
+ neg_sorted_words = [tokenizer.decode(sorted_indices[-j-1]) for j in range(neg_count)]
157
+ neg_df = pd.DataFrame(neg_sorted_words, columns=["Top Negative"])
158
+ neg_dfs.append(neg_df)
159
+
160
+ sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, \
161
+ sense6words, sense7words, sense8words, sense9words, sense10words, sense11words, \
162
+ sense12words, sense13words, sense14words, sense15words = pos_dfs
163
+
164
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, \
165
+ sense6negwords, sense7negwords, sense8negwords, sense9negwords, sense10negwords, sense11negwords, \
166
+ sense12negwords, sense13negwords, sense14negwords, sense15negwords = neg_dfs
167
+
168
+ sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, \
169
+ sense6slider, sense7slider, sense8slider, sense9slider, sense10slider, sense11slider, \
170
+ sense12slider, sense13slider, sense14slider, sense15slider = token_contextualization_weights_list
171
+
172
+ return token, token_index, \
173
+ sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words, \
174
+ sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words, \
175
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords, \
176
+ sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords, \
177
+ sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider, \
178
+ sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider
179
+
180
+ """
181
+ Wrapper for when the user selects a new token in the tokens dataframe.
182
+ Converts `evt` (the selected token) to `token` and `token_index` which are used by get_token_contextual_weights.
183
+ """
184
+ def new_token_contextual_weights (contextualization_weights, length, evt: gr.SelectData, pos_count = 5, neg_count = 3):
185
+ print(">>>>>in new_token_contextual_weights")
186
+ token_index = evt.index[1] # selected token is the token_index-th token in the sentence
187
+ token = evt.value
188
+ if not token:
189
+ return None, None, \
190
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
191
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, \
192
+ None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
193
+ return get_token_contextual_weights (contextualization_weights, length, token, token_index, pos_count, neg_count)
194
+
195
+ def change_sense0_weight(contextualization_weights, length, token_index, new_weight):
196
+ contextualization_weights[0, 0, length-1, token_index] = new_weight
197
+ return contextualization_weights
198
+ def change_sense1_weight(contextualization_weights, length, token_index, new_weight):
199
+ contextualization_weights[0, 1, length-1, token_index] = new_weight
200
+ return contextualization_weights
201
+ def change_sense2_weight(contextualization_weights, length, token_index, new_weight):
202
+ contextualization_weights[0, 2, length-1, token_index] = new_weight
203
+ return contextualization_weights
204
+ def change_sense3_weight(contextualization_weights, length, token_index, new_weight):
205
+ contextualization_weights[0, 3, length-1, token_index] = new_weight
206
+ return contextualization_weights
207
+ def change_sense4_weight(contextualization_weights, length, token_index, new_weight):
208
+ contextualization_weights[0, 4, length-1, token_index] = new_weight
209
+ return contextualization_weights
210
+ def change_sense5_weight(contextualization_weights, length, token_index, new_weight):
211
+ contextualization_weights[0, 5, length-1, token_index] = new_weight
212
+ return contextualization_weights
213
+ def change_sense6_weight(contextualization_weights, length, token_index, new_weight):
214
+ contextualization_weights[0, 6, length-1, token_index] = new_weight
215
+ return contextualization_weights
216
+ def change_sense7_weight(contextualization_weights, length, token_index, new_weight):
217
+ contextualization_weights[0, 7, length-1, token_index] = new_weight
218
+ return contextualization_weights
219
+ def change_sense8_weight(contextualization_weights, length, token_index, new_weight):
220
+ contextualization_weights[0, 8, length-1, token_index] = new_weight
221
+ return contextualization_weights
222
+ def change_sense9_weight(contextualization_weights, length, token_index, new_weight):
223
+ contextualization_weights[0, 9, length-1, token_index] = new_weight
224
+ return contextualization_weights
225
+ def change_sense10_weight(contextualization_weights, length, token_index, new_weight):
226
+ contextualization_weights[0, 10, length-1, token_index] = new_weight
227
+ return contextualization_weights
228
+ def change_sense11_weight(contextualization_weights, length, token_index, new_weight):
229
+ contextualization_weights[0, 11, length-1, token_index] = new_weight
230
+ return contextualization_weights
231
+ def change_sense12_weight(contextualization_weights, length, token_index, new_weight):
232
+ contextualization_weights[0, 12, length-1, token_index] = new_weight
233
+ return contextualization_weights
234
+ def change_sense13_weight(contextualization_weights, length, token_index, new_weight):
235
+ contextualization_weights[0, 13, length-1, token_index] = new_weight
236
+ return contextualization_weights
237
+ def change_sense14_weight(contextualization_weights, length, token_index, new_weight):
238
+ contextualization_weights[0, 14, length-1, token_index] = new_weight
239
+ return contextualization_weights
240
+ def change_sense15_weight(contextualization_weights, length, token_index, new_weight):
241
+ contextualization_weights[0, 15, length-1, token_index] = new_weight
242
+ return contextualization_weights
243
+
244
+ """
245
+ Clears all gr.State variables used to store info across methods when the input sentence changes.
246
+ """
247
+ def clear_states(contextualization_weights, token_index, length):
248
+ contextualization_weights = None
249
+ token_index = None
250
+ length = 0
251
+ return contextualization_weights, token_index, length
252
+
253
+ def reset_weights(contextualization_weights):
254
+ print("Resetting weights...")
255
+ contextualization_weights = None
256
+ return contextualization_weights
257
+
258
+ with gr.Blocks( theme = gr.themes.Base(),
259
+ css = """#sense0slider, #sense1slider, #sense2slider, #sense3slider, #sense4slider, #sense5slider, #sense6slider, #sense7slider,
260
+ #sense8slider, #sense9slider, #sense1slider0, #sense11slider, #sense12slider, #sense13slider, #sense14slider, #sense15slider
261
+ { height: 200px; width: 200px; transform: rotate(270deg); }"""
262
+ ) as demo:
263
+
264
+ gr.Markdown("""
265
+ ## Backpack Sense Visualization
266
+ """)
267
+
268
+ with gr.Tab("Language Modeling"):
269
+ contextualization_weights = gr.State(None) # store session data for sharing between functions
270
+ token_index = gr.State(None)
271
+ length = gr.State(0)
272
+ top_k = gr.State(10)
273
+ with gr.Row():
274
+ with gr.Column(scale=8):
275
+ input_sentence = gr.Textbox(label="Input Sentence", placeholder='Enter a sentence and click "Predict next word". Then, you can go to the Tokens section, click on a token, and see its contextualization weights.')
276
+ with gr.Column(scale=1):
277
+ predict = gr.Button(value="Predict next word", variant="primary")
278
+ reset_weights_button = gr.Button("Reset weights")
279
+ gr.Markdown("""#### Top-k predicted next word""")
280
+ top_k_words = gr.Dataframe(interactive=False)
281
+ gr.Markdown("""### **Token Breakdown:** click on a token below to see its senses and contextualization weights""")
282
+ tokens = gr.DataFrame()
283
+ with gr.Row():
284
+ with gr.Column(scale=1):
285
+ selected_token = gr.Textbox(label="Current Selected Token", interactive=False)
286
+ with gr.Column(scale=8):
287
+ gr.Markdown("""####
288
+ Once a token is chosen, you can **use the sliders below to change the weight of any sense or multiple senses** for that token, \
289
+ and then click "Predict next word" to see updated next-word predictions. Erase all changes with "Reset weights".
290
+ """)
291
+ # sense sliders and top sense words dataframes
292
+ with gr.Row():
293
+ with gr.Column(scale=0, min_width=120):
294
+ sense0slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 0", elem_id="sense0slider", interactive=True)
295
+ with gr.Column(scale=0, min_width=120):
296
+ sense1slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 1", elem_id="sense1slider", interactive=True)
297
+ with gr.Column(scale=0, min_width=120):
298
+ sense2slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 2", elem_id="sense2slider", interactive=True)
299
+ with gr.Column(scale=0, min_width=120):
300
+ sense3slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 3", elem_id="sense3slider", interactive=True)
301
+ with gr.Column(scale=0, min_width=120):
302
+ sense4slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 4", elem_id="sense4slider", interactive=True)
303
+ with gr.Column(scale=0, min_width=120):
304
+ sense5slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 5", elem_id="sense5slider", interactive=True)
305
+ with gr.Column(scale=0, min_width=120):
306
+ sense6slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 6", elem_id="sense6slider", interactive=True)
307
+ with gr.Column(scale=0, min_width=120):
308
+ sense7slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 7", elem_id="sense7slider", interactive=True)
309
+ with gr.Row():
310
+ with gr.Column(scale=0, min_width=120):
311
+ sense0words = gr.DataFrame(headers = ["Sense 0"])
312
+ with gr.Column(scale=0, min_width=120):
313
+ sense1words = gr.DataFrame(headers = ["Sense 1"])
314
+ with gr.Column(scale=0, min_width=120):
315
+ sense2words = gr.DataFrame(headers = ["Sense 2"])
316
+ with gr.Column(scale=0, min_width=120):
317
+ sense3words = gr.DataFrame(headers = ["Sense 3"])
318
+ with gr.Column(scale=0, min_width=120):
319
+ sense4words = gr.DataFrame(headers = ["Sense 4"])
320
+ with gr.Column(scale=0, min_width=120):
321
+ sense5words = gr.DataFrame(headers = ["Sense 5"])
322
+ with gr.Column(scale=0, min_width=120):
323
+ sense6words = gr.DataFrame(headers = ["Sense 6"])
324
+ with gr.Column(scale=0, min_width=120):
325
+ sense7words = gr.DataFrame(headers = ["Sense 7"])
326
+ with gr.Row():
327
+ with gr.Column(scale=0, min_width=120):
328
+ sense0negwords = gr.DataFrame(headers = ["Top Negative"])
329
+ with gr.Column(scale=0, min_width=120):
330
+ sense1negwords = gr.DataFrame(headers = ["Top Negative"])
331
+ with gr.Column(scale=0, min_width=120):
332
+ sense2negwords = gr.DataFrame(headers = ["Top Negative"])
333
+ with gr.Column(scale=0, min_width=120):
334
+ sense3negwords = gr.DataFrame(headers = ["Top Negative"])
335
+ with gr.Column(scale=0, min_width=120):
336
+ sense4negwords = gr.DataFrame(headers = ["Top Negative"])
337
+ with gr.Column(scale=0, min_width=120):
338
+ sense5negwords = gr.DataFrame(headers = ["Top Negative"])
339
+ with gr.Column(scale=0, min_width=120):
340
+ sense6negwords = gr.DataFrame(headers = ["Top Negative"])
341
+ with gr.Column(scale=0, min_width=120):
342
+ sense7negwords = gr.DataFrame(headers = ["Top Negative"])
343
+ with gr.Row():
344
+ with gr.Column(scale=0, min_width=120):
345
+ sense8slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 8", elem_id="sense8slider", interactive=True)
346
+ with gr.Column(scale=0, min_width=120):
347
+ sense9slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 9", elem_id="sense9slider", interactive=True)
348
+ with gr.Column(scale=0, min_width=120):
349
+ sense10slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 10", elem_id="sense1slider0", interactive=True)
350
+ with gr.Column(scale=0, min_width=120):
351
+ sense11slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 11", elem_id="sense11slider", interactive=True)
352
+ with gr.Column(scale=0, min_width=120):
353
+ sense12slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 12", elem_id="sense12slider", interactive=True)
354
+ with gr.Column(scale=0, min_width=120):
355
+ sense13slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 13", elem_id="sense13slider", interactive=True)
356
+ with gr.Column(scale=0, min_width=120):
357
+ sense14slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 14", elem_id="sense14slider", interactive=True)
358
+ with gr.Column(scale=0, min_width=120):
359
+ sense15slider= gr.Slider(minimum=0, maximum=1, value=0, step=0.01, label="Sense 15", elem_id="sense15slider", interactive=True)
360
+ with gr.Row():
361
+ with gr.Column(scale=0, min_width=120):
362
+ sense8words = gr.DataFrame(headers = ["Sense 8"])
363
+ with gr.Column(scale=0, min_width=120):
364
+ sense9words = gr.DataFrame(headers = ["Sense 9"])
365
+ with gr.Column(scale=0, min_width=120):
366
+ sense10words = gr.DataFrame(headers = ["Sense 10"])
367
+ with gr.Column(scale=0, min_width=120):
368
+ sense11words = gr.DataFrame(headers = ["Sense 11"])
369
+ with gr.Column(scale=0, min_width=120):
370
+ sense12words = gr.DataFrame(headers = ["Sense 12"])
371
+ with gr.Column(scale=0, min_width=120):
372
+ sense13words = gr.DataFrame(headers = ["Sense 13"])
373
+ with gr.Column(scale=0, min_width=120):
374
+ sense14words = gr.DataFrame(headers = ["Sense 14"])
375
+ with gr.Column(scale=0, min_width=120):
376
+ sense15words = gr.DataFrame(headers = ["Sense 15"])
377
+ with gr.Row():
378
+ with gr.Column(scale=0, min_width=120):
379
+ sense8negwords = gr.DataFrame(headers = ["Top Negative"])
380
+ with gr.Column(scale=0, min_width=120):
381
+ sense9negwords = gr.DataFrame(headers = ["Top Negative"])
382
+ with gr.Column(scale=0, min_width=120):
383
+ sense10negwords = gr.DataFrame(headers = ["Top Negative"])
384
+ with gr.Column(scale=0, min_width=120):
385
+ sense11negwords = gr.DataFrame(headers = ["Top Negative"])
386
+ with gr.Column(scale=0, min_width=120):
387
+ sense12negwords = gr.DataFrame(headers = ["Top Negative"])
388
+ with gr.Column(scale=0, min_width=120):
389
+ sense13negwords = gr.DataFrame(headers = ["Top Negative"])
390
+ with gr.Column(scale=0, min_width=120):
391
+ sense14negwords = gr.DataFrame(headers = ["Top Negative"])
392
+ with gr.Column(scale=0, min_width=120):
393
+ sense15negwords = gr.DataFrame(headers = ["Top Negative"])
394
+ gr.Markdown("""Note: **"Top Negative"** shows words that have the most negative dot products with the sense vector, which can exhibit more coherent meaning than those with the most positive dot products.
395
+ To see more representative words of each sense, scroll to the top and use the **"Individual Word Sense Look Up"** tab.""")
396
+ # gr.Examples(
397
+ # examples=[["Messi plays for", top_k, None]],
398
+ # inputs=[input_sentence, top_k, contextualization_weights],
399
+ # outputs=[tokens, top_k_words, length, contextualization_weights],
400
+ # fn=predict_next_word,
401
+ # )
402
+
403
+ sense0slider.change(fn=change_sense0_weight,
404
+ inputs=[contextualization_weights, length, token_index, sense0slider],
405
+ outputs=[contextualization_weights])
406
+ sense1slider.change(fn=change_sense1_weight,
407
+ inputs=[contextualization_weights, length, token_index, sense1slider],
408
+ outputs=[contextualization_weights])
409
+ sense2slider.change(fn=change_sense2_weight,
410
+ inputs=[contextualization_weights, length, token_index, sense2slider],
411
+ outputs=[contextualization_weights])
412
+ sense3slider.change(fn=change_sense3_weight,
413
+ inputs=[contextualization_weights, length, token_index, sense3slider],
414
+ outputs=[contextualization_weights])
415
+ sense4slider.change(fn=change_sense4_weight,
416
+ inputs=[contextualization_weights, length, token_index, sense4slider],
417
+ outputs=[contextualization_weights])
418
+ sense5slider.change(fn=change_sense5_weight,
419
+ inputs=[contextualization_weights, length, token_index, sense5slider],
420
+ outputs=[contextualization_weights])
421
+ sense6slider.change(fn=change_sense6_weight,
422
+ inputs=[contextualization_weights, length, token_index, sense6slider],
423
+ outputs=[contextualization_weights])
424
+ sense7slider.change(fn=change_sense7_weight,
425
+ inputs=[contextualization_weights, length, token_index, sense7slider],
426
+ outputs=[contextualization_weights])
427
+ sense8slider.change(fn=change_sense8_weight,
428
+ inputs=[contextualization_weights, length, token_index, sense8slider],
429
+ outputs=[contextualization_weights])
430
+ sense9slider.change(fn=change_sense9_weight,
431
+ inputs=[contextualization_weights, length, token_index, sense9slider],
432
+ outputs=[contextualization_weights])
433
+ sense10slider.change(fn=change_sense10_weight,
434
+ inputs=[contextualization_weights, length, token_index, sense10slider],
435
+ outputs=[contextualization_weights])
436
+ sense11slider.change(fn=change_sense11_weight,
437
+ inputs=[contextualization_weights, length, token_index, sense11slider],
438
+ outputs=[contextualization_weights])
439
+ sense12slider.change(fn=change_sense12_weight,
440
+ inputs=[contextualization_weights, length, token_index, sense12slider],
441
+ outputs=[contextualization_weights])
442
+ sense13slider.change(fn=change_sense13_weight,
443
+ inputs=[contextualization_weights, length, token_index, sense13slider],
444
+ outputs=[contextualization_weights])
445
+ sense14slider.change(fn=change_sense14_weight,
446
+ inputs=[contextualization_weights, length, token_index, sense14slider],
447
+ outputs=[contextualization_weights])
448
+ sense15slider.change(fn=change_sense15_weight,
449
+ inputs=[contextualization_weights, length, token_index, sense15slider],
450
+ outputs=[contextualization_weights])
451
+
452
+
453
+ predict.click(
454
+ fn=predict_next_word,
455
+ inputs = [input_sentence, top_k, contextualization_weights],
456
+ outputs= [tokens, top_k_words, length, contextualization_weights],
457
+ )
458
+
459
+ tokens.select(fn=new_token_contextual_weights,
460
+ inputs=[contextualization_weights, length],
461
+ outputs= [selected_token, token_index,
462
+
463
+ sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
464
+ sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
465
+
466
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
467
+ sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
468
+
469
+ sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
470
+ sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
471
+ )
472
+
473
+ reset_weights_button.click(
474
+ fn=reset_weights,
475
+ inputs=[contextualization_weights],
476
+ outputs=[contextualization_weights]
477
+ ).success(
478
+ fn=predict_next_word,
479
+ inputs = [input_sentence, top_k, contextualization_weights],
480
+ outputs= [tokens, top_k_words, length, contextualization_weights],
481
+ ).success(
482
+ fn=get_token_contextual_weights,
483
+ inputs=[contextualization_weights, length, selected_token, token_index],
484
+ outputs= [selected_token, token_index,
485
+
486
+ sense0words, sense1words, sense2words, sense3words, sense4words, sense5words, sense6words, sense7words,
487
+ sense8words, sense9words, sense10words, sense11words, sense12words, sense13words, sense14words, sense15words,
488
+
489
+ sense0negwords, sense1negwords, sense2negwords, sense3negwords, sense4negwords, sense5negwords, sense6negwords, sense7negwords,
490
+ sense8negwords, sense9negwords, sense10negwords, sense11negwords, sense12negwords, sense13negwords, sense14negwords, sense15negwords,
491
+
492
+ sense0slider, sense1slider, sense2slider, sense3slider, sense4slider, sense5slider, sense6slider, sense7slider,
493
+ sense8slider, sense9slider, sense10slider, sense11slider, sense12slider, sense13slider, sense14slider, sense15slider]
494
+ )
495
+
496
+ input_sentence.change(
497
+ fn=clear_states,
498
+ inputs=[contextualization_weights, token_index, length],
499
+ outputs=[contextualization_weights, token_index, length]
500
+ )
501
+
502
+ with gr.Tab("Individual Word Sense Look Up"):
503
+ gr.Markdown("""> Note on tokenization: Backpack uses the GPT-2 tokenizer, which includes the space before a word as part \
504
+ of the token, so by default, a space character `' '` is added to the beginning of the word \
505
+ you look up. You can disable this by checking `Remove space before word`, but know this might \
506
+ cause strange behaviors like breaking `afraid` into `af` and `raid`, or `slight` into `s` and `light`.
507
+ """)
508
+ with gr.Row():
509
+ word = gr.Textbox(label="Word", placeholder="e.g. science")
510
+ token_breakdown = gr.Textbox(label="Token Breakdown (senses are for the first token only)")
511
+ remove_space = gr.Checkbox(label="Remove space before word", default=False)
512
+ count = gr.Slider(minimum=1, maximum=20, value=10, label="Top K", step=1)
513
+ look_up_button = gr.Button("Look up")
514
+ pos_outputs = gr.Dataframe(label="Highest Scoring Senses")
515
+ neg_outputs = gr.Dataframe(label="Lowest Scoring Senses")
516
+ gr.Examples(
517
+ examples=["science", "afraid", "book", "slight"],
518
+ inputs=[word],
519
+ outputs=[pos_outputs, neg_outputs, token_breakdown],
520
+ fn=visualize_word,
521
+ cache_examples=True,
522
+ )
523
+
524
+ look_up_button.click(
525
+ fn=visualize_word,
526
+ inputs= [word, count, remove_space],
527
+ outputs= [pos_outputs, neg_outputs, token_breakdown],
528
+ )
529
+
530
+ demo.launch()
531
+
532
+
533
+ # Code for generating slider functions & event listners
534
+
535
+ # for i in range(16):
536
+ # print(
537
+ # f"""def change_sense{i}_weight(contextualization_weights, length, token_index, new_weight):
538
+ # print(f"Changing weight for the {i}th sense of the {{token_index}}th token.")
539
+ # print("new_weight to be assigned = ", new_weight)
540
+ # contextualization_weights[0, {i}, length-1, token_index] = new_weight
541
+ # print("contextualization_weights: ", contextualization_weights[0, :, length-1, token_index])
542
+ # return contextualization_weights"""
543
+ # )
544
+
545
+ # for i in range(16):
546
+ # print(
547
+ # f""" sense{i}slider.change(fn=change_sense{i}_weight,
548
+ # inputs=[contextualization_weights, length, token_index, sense{i}slider],
549
+ # outputs=[contextualization_weights])"""
550
+ # )
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ pandas
3
+ transformers
4
+ gradio
senses/all_vecs_mtx.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f0c9de5688dd793470c40ebc3b49c29be6ddbf9a38804bca64512940671e129
3
+ size 2470232826
senses/lm_head.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f94054e64b4d1a07e18443769df4d3b9e346c00b02ffe4e9579e8313034dac24
3
+ size 154411755
senses/use_senses.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualize some sense vectors"""
2
+
3
+ import torch
4
+ import argparse
5
+
6
+ import transformers
7
+
8
+ def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
9
+ """
10
+ Prints out the top-scoring words (and lowest-scoring words) for each sense.
11
+
12
+ """
13
+ if contents is None:
14
+ print(word)
15
+ token_id = tokenizer(word)['input_ids'][0]
16
+ contents = vecs[token_id] # torch.Size([16, 768])
17
+
18
+ for i in range(contents.shape[0]):
19
+ print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
20
+ logits = contents[i,:] @ lm_head.t() # (vocab,) [768] @ [768, 50257] -> [50257]
21
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
22
+ print('~~~Positive~~~')
23
+ for j in range(count):
24
+ print(tokenizer.decode(sorted_indices[j]), '\t','{:.2f}'.format(sorted_logits[j].item()))
25
+ print('~~~Negative~~~')
26
+ for j in range(count):
27
+ print(tokenizer.decode(sorted_indices[-j-1]), '\t','{:.2f}'.format(sorted_logits[-j-1].item()))
28
+ return contents
29
+ print()
30
+ print()
31
+ print()
32
+
33
+ argp = argparse.ArgumentParser()
34
+ argp.add_argument('vecs_path')
35
+ argp.add_argument('lm_head_path')
36
+ args = argp.parse_args()
37
+
38
+ # Load tokenizer and parameters
39
+ tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
40
+ vecs = torch.load(args.vecs_path)
41
+ lm_head = torch.load(args.lm_head_path)
42
+
43
+ visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)
44
+