dar-tau commited on
Commit
2fcc96e
1 Parent(s): 53d5a65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -56,9 +56,9 @@ def get_hidden_states(progress, raw_original_prompt):
56
  return [hidden_states, *token_btns]
57
 
58
 
59
- def run_interpretation(progress, global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
60
  temperature, top_k, top_p, repetition_penalty, length_penalty, i,
61
- num_beams=1):
62
 
63
  interpreted_vectors = global_state[:, i]
64
  length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
@@ -184,10 +184,9 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
184
  tokens_container.append(btn)
185
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
186
  for i in range(model.config.num_hidden_layers)]
187
- progress = gr.Progress()
188
  for i, btn in enumerate(tokens_container):
189
- btn.click(partial(run_interpretation, i=i), [progress,
190
- global_state, interpretation_prompt, num_tokens, do_sample, temperature,
191
  top_k, top_p, repetition_penalty, length_penalty
192
  ], [*interpretation_bubbles])
193
 
 
56
  return [hidden_states, *token_btns]
57
 
58
 
59
+ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
60
  temperature, top_k, top_p, repetition_penalty, length_penalty, i,
61
+ num_beams=1, progress=gr.Progress()):
62
 
63
  interpreted_vectors = global_state[:, i]
64
  length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
184
  tokens_container.append(btn)
185
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
186
  for i in range(model.config.num_hidden_layers)]
187
+
188
  for i, btn in enumerate(tokens_container):
189
+ btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
 
190
  top_k, top_p, repetition_penalty, length_penalty
191
  ], [*interpretation_bubbles])
192