dar-tau committed
Commit 765296c
1 Parent(s): 34b25c9

Update app.py

Files changed (1):
  1. app.py +21 -7
app.py CHANGED
@@ -82,7 +82,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     # generate the interpretations
     generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
-    return generation_texts
+    return [gr.Text(text, visible=True) for text in generation_texts]
 
 
 ## main
@@ -103,7 +103,14 @@ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_t
 
 # demo
 json_output = gr.JSON()
-css = ''
+css = '''
+.bubble {
+    border: 2px solid #000;
+    border-radius: 10px;
+    padding: 10px;
+}
+'''
+
 
 # '''
 # .token_btn{
@@ -165,14 +172,21 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
 
     with gr.Group('Output'):
         tokens_container = []
+        interpretation_bubbles = []
         with gr.Row():
             for i in range(MAX_PROMPT_TOKENS):
                 btn = gr.Button('', visible=False, elem_classes=['token_btn'])
-                btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
-                                                             top_k, top_p, repetition_penalty, length_penalty
-                          ], [json_output])
                 tokens_container.append(btn)
-    json_output.render()
 
-    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [global_state, *tokens_container])
+    for i in range(model.config.num_hidden_layers):
+        interpretation_bubbles.append(gr.Text('', visible=False, elem_classes=['bubble']))
+
+    for i, btn in enumerate(tokens_container):
+        btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
+                                                     top_k, top_p, repetition_penalty, length_penalty
+                  ], [*interpretation_bubbles])
+
+    original_prompt_btn.click(get_hidden_states,
+                              [original_prompt_raw],
+                              [global_state, *tokens_container])
 demo.launch()
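
For context, the pattern this commit moves to can be shown with a minimal, self-contained sketch. This is not the Space's code: fill_bubbles, N_BUBBLES, and the prompt box are illustrative stand-ins, and it assumes Gradio 4.x behavior, where an event handler may return component instances (here gr.Text(..., visible=True)) as updates. The handler returns one update per pre-declared, initially hidden "bubble", and Gradio matches the returned list to the outputs list by position, which is why run_interpretation now returns a list of gr.Text components instead of the raw generation_texts.

import gradio as gr

N_BUBBLES = 4  # illustrative stand-in for model.config.num_hidden_layers

# same .bubble rule the commit adds, applied through elem_classes below
css = '''
.bubble {
    border: 2px solid #000;
    border-radius: 10px;
    padding: 10px;
}
'''

def fill_bubbles(prompt):
    # one returned value per output component, in the same order as `outputs`
    texts = [f'layer {i}: interpretation of "{prompt}"' for i in range(N_BUBBLES)]
    return [gr.Text(t, visible=True) for t in texts]

with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label='Prompt')
    run_btn = gr.Button('Interpret')
    # hidden Text components that the click handler reveals and fills
    bubbles = [gr.Text('', visible=False, elem_classes=['bubble'])
               for _ in range(N_BUBBLES)]
    run_btn.click(fill_bubbles, [prompt], bubbles)

demo.launch()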