Update app.py
Browse files
app.py
CHANGED
@@ -82,7 +82,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
82 |
# generate the interpretations
|
83 |
generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
84 |
generation_texts = tokenizer.batch_decode(generated)
|
85 |
-
return generation_texts
|
86 |
|
87 |
|
88 |
## main
|
@@ -103,7 +103,14 @@ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_t
|
|
103 |
|
104 |
# demo
|
105 |
json_output = gr.JSON()
|
106 |
-
css = ''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
# '''
|
109 |
# .token_btn{
|
@@ -165,14 +172,21 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
|
|
165 |
|
166 |
with gr.Group('Output'):
|
167 |
tokens_container = []
|
|
|
168 |
with gr.Row():
|
169 |
for i in range(MAX_PROMPT_TOKENS):
|
170 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
171 |
-
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
|
172 |
-
top_k, top_p, repetition_penalty, length_penalty
|
173 |
-
], [json_output])
|
174 |
tokens_container.append(btn)
|
175 |
-
json_output.render()
|
176 |
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
demo.launch()
|
|
|
82 |
# generate the interpretations
|
83 |
generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
84 |
generation_texts = tokenizer.batch_decode(generated)
|
85 |
+
return [gr.Text(text, visible=True) for text in generation_texts]
|
86 |
|
87 |
|
88 |
## main
|
|
|
103 |
|
104 |
# demo
|
105 |
json_output = gr.JSON()
|
106 |
+
css = '''
|
107 |
+
.bubble {
|
108 |
+
border: 2px solid #000;
|
109 |
+
border-radius: 10px;
|
110 |
+
padding: 10px;
|
111 |
+
}
|
112 |
+
'''
|
113 |
+
|
114 |
|
115 |
# '''
|
116 |
# .token_btn{
|
|
|
172 |
|
173 |
with gr.Group('Output'):
|
174 |
tokens_container = []
|
175 |
+
interpretation_bubbles = []
|
176 |
with gr.Row():
|
177 |
for i in range(MAX_PROMPT_TOKENS):
|
178 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
|
|
|
|
|
|
179 |
tokens_container.append(btn)
|
|
|
180 |
|
181 |
+
for i in range(model.config.num_hidden_layers):
|
182 |
+
interpretation_bubbles.append(gr.Text('', visible=False, elrm_classes=['bubble']))
|
183 |
+
|
184 |
+
for i, btn in enumerate(tokens_container):
|
185 |
+
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
|
186 |
+
top_k, top_p, repetition_penalty, length_penalty
|
187 |
+
], [*interpretation_bubbles])
|
188 |
+
|
189 |
+
original_prompt_btn.click(get_hidden_states,
|
190 |
+
[original_prompt_raw],
|
191 |
+
[global_state, *tokens_container])
|
192 |
demo.launch()
|