dar-tau committed on
Commit 9d7840a
1 Parent(s): 9b5c8c6

Update app.py

Files changed (1)
  1. app.py +11 -6
app.py CHANGED
@@ -45,7 +45,7 @@ suggested_interpretation_prompts = ["Before responding, let me repeat the messag
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt):
+def get_hidden_states(raw_original_prompt, interpretation_args, interpretation_outputs):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
@@ -55,8 +55,7 @@ def get_hidden_states(raw_original_prompt):
     for i, token in enumerate(tokens):
         btn = gr.Button(token)
         btn.click(partial(run_interpretation, interpreted_vectors=hidden_states[:, i]),
-                  [interpretation_prompt, num_tokens, do_sample, temperature, top_k, top_p, repetition_penalty, length_penalty],
-                  [json_output])
+                  interpretation_args, interpretation_outputs)
         token_btns.append(btn)
     token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]
     return token_btns
@@ -143,11 +142,17 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
 
     with gr.Group('Interpretation'):
         interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
-
+
     with gr.Group('Output'):
         with gr.Row():
             tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
         json_output = gr.JSON()
-
-    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])
+
+    interpretation_args = [interpretation_prompt, num_tokens, do_sample, temperature,
+                           top_k, top_p, repetition_penalty, length_penalty]
+    interpretation_outputs = [json_output]
+
+    original_prompt_btn.click(partial(get_hidden_states, interpretation_args=interpretation_args,
+                                      interpretation_outputs=interpretation_outputs
+                                      ), [original_prompt_raw], [*tokens_container])
     demo.launch()
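
For context, here is a minimal, self-contained sketch of the Gradio wiring pattern this commit moves to: the lists of input and output components for .click() are built once and handed to the wiring function through functools.partial, instead of being hard-coded inside it. This is not the Space's actual code; the component names, handler signature, and values below are illustrative only.

from functools import partial
import gradio as gr

def run_interpretation(prompt, temperature, token_index=None):
    # stand-in for the real interpretation call
    return {"token": token_index, "prompt": prompt, "temperature": temperature}

def wire_token_buttons(buttons, interpretation_args, interpretation_outputs):
    # attach one handler per token button; the input/output component
    # lists arrive as arguments rather than being referenced as globals
    for i, btn in enumerate(buttons):
        btn.click(partial(run_interpretation, token_index=i),
                  interpretation_args, interpretation_outputs)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    temperature = gr.Slider(0.0, 2.0, value=1.0, label='Temperature')
    token_btns = [gr.Button(f'token {i}') for i in range(4)]
    json_output = gr.JSON()

    # define the argument/output lists once, then pass them in via partial
    interpretation_args = [prompt, temperature]
    interpretation_outputs = [json_output]
    wire_token_buttons(token_btns, interpretation_args, interpretation_outputs)

demo.launch()

Passing the component lists as arguments keeps the wiring function independent of components that are only defined later inside the gr.Blocks context, which appears to be the motivation for the change in this commit.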