Update app.py
Browse files
app.py
CHANGED
@@ -45,7 +45,7 @@ suggested_interpretation_prompts = ["Before responding, let me repeat the messag
|
|
45 |
def initialize_gpu():
|
46 |
pass
|
47 |
|
48 |
-
def get_hidden_states(raw_original_prompt):
|
49 |
original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
|
50 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
51 |
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
|
@@ -55,8 +55,7 @@ def get_hidden_states(raw_original_prompt):
|
|
55 |
for i, token in enumerate(tokens):
|
56 |
btn = gr.Button(token)
|
57 |
btn.click(partial(run_interpretation, interpreted_vectors=hidden_states[:, i]),
|
58 |
-
|
59 |
-
[json_output])
|
60 |
token_btns.append(btn)
|
61 |
token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]
|
62 |
return token_btns
|
@@ -143,11 +142,17 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
|
|
143 |
|
144 |
with gr.Group('Interpretation'):
|
145 |
interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
|
146 |
-
|
147 |
with gr.Group('Output'):
|
148 |
with gr.Row():
|
149 |
tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
|
150 |
json_output = gr.JSON()
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
demo.launch()
|
|
|
45 |
def initialize_gpu():
|
46 |
pass
|
47 |
|
48 |
+
def get_hidden_states(raw_original_prompt, interpretation_args, interpretation_outputs):
|
49 |
original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
|
50 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
51 |
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
|
|
|
55 |
for i, token in enumerate(tokens):
|
56 |
btn = gr.Button(token)
|
57 |
btn.click(partial(run_interpretation, interpreted_vectors=hidden_states[:, i]),
|
58 |
+
interpretation_args, interpretation_outputs)
|
|
|
59 |
token_btns.append(btn)
|
60 |
token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]
|
61 |
return token_btns
|
|
|
142 |
|
143 |
with gr.Group('Interpretation'):
|
144 |
interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
|
145 |
+
|
146 |
with gr.Group('Output'):
|
147 |
with gr.Row():
|
148 |
tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
|
149 |
json_output = gr.JSON()
|
150 |
+
|
151 |
+
interpretation_args = [interpretation_prompt, num_tokens, do_sample, temperature,
|
152 |
+
top_k, top_p, repetition_penalty, length_penalty]
|
153 |
+
interpretation_outputs = [json_output]
|
154 |
+
|
155 |
+
original_prompt_btn.click(partial(get_hidden_states, interpretation_args=interpretation_args,
|
156 |
+
interpretation_outputs=interpretation_outputs
|
157 |
+
), [original_prompt_raw], [*tokens_container])
|
158 |
demo.launch()
|