Update app.py
Browse files
app.py
CHANGED
@@ -56,9 +56,9 @@ def get_hidden_states(progress, raw_original_prompt):
|
|
56 |
return [hidden_states, *token_btns]
|
57 |
|
58 |
|
59 |
-
def run_interpretation(
|
60 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
61 |
-
num_beams=1):
|
62 |
|
63 |
interpreted_vectors = global_state[:, i]
|
64 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
@@ -184,10 +184,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
|
|
184 |
tokens_container.append(btn)
|
185 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
|
186 |
for i in range(model.config.num_hidden_layers)]
|
187 |
-
|
188 |
for i, btn in enumerate(tokens_container):
|
189 |
-
btn.click(partial(run_interpretation, i=i), [
|
190 |
-
global_state, interpretation_prompt, num_tokens, do_sample, temperature,
|
191 |
top_k, top_p, repetition_penalty, length_penalty
|
192 |
], [*interpretation_bubbles])
|
193 |
|
|
|
56 |
return [hidden_states, *token_btns]
|
57 |
|
58 |
|
59 |
+
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
60 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
61 |
+
num_beams=1, progress=gr.Progress()):
|
62 |
|
63 |
interpreted_vectors = global_state[:, i]
|
64 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
|
|
184 |
tokens_container.append(btn)
|
185 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
|
186 |
for i in range(model.config.num_hidden_layers)]
|
187 |
+
|
188 |
for i, btn in enumerate(tokens_container):
|
189 |
+
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
|
|
|
190 |
top_k, top_p, repetition_penalty, length_penalty
|
191 |
], [*interpretation_bubbles])
|
192 |
|