Update app.py
app.py CHANGED
@@ -50,14 +50,19 @@ def get_hidden_states(raw_original_prompt):
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-
-
-
-
-
-
-
-
+    token_btns = []
+    for i, token in enumerate(tokens):
+        btn = gr.Button(token)
+        btn.click(partial(run_interpretation, interpreted_vectors=hidden_states[:, i]),
+                  [interpretation_prompt, num_tokens, do_sample, temperature, top_k, top_p, repetition_penalty, length_penalty],
+                  [json_output])
+        token_btns.append(btn)
+    token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]
+    return token_btns
+
+
+def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, interpreted_vectors, num_beams=1):
 
     length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
@@ -77,16 +82,9 @@ def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do
     interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
-    # compute the hidden stated from the original prompt (after putting it in the right template)
-    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
-    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
-    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
-    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-
     # generate the interpretations
     generated = interpretation_prompt.generate(model, {0: hidden_states[:, -1]}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
-    # tokens = [x.lstrip('▁') for x in tokenizer.tokenize(text)]
     return generation_texts
 
 
@@ -148,13 +146,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     with gr.Group('Output'):
         with gr.Row():
            tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
-
-
-
-    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])
-    # btn.click(run_model,
-    #           [text, interpretation_prompt, num_tokens, do_sample, temperature,
-    #           top_k, top_p, repetition_penalty, length_penalty],
-    #           [tokens_container])
-
+        json_output = gr.JSON()
+
+    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])
 demo.launch()
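
The first hunk's key move: each prompt token gets its own gr.Button, and functools.partial bakes that token's hidden-state column (hidden_states[:, i]) into the button's click handler, so the vector itself never travels through Gradio's inputs list. Below is a minimal, model-free sketch of that binding pattern, assuming Gradio 4.x; run_interpretation's body and the payload string are illustrative stand-ins, not the Space's actual code.

from functools import partial

import gradio as gr


def run_interpretation(interpretation_prompt, payload):
    # Stand-in for the Space's run_interpretation: the real handler receives
    # interpreted_vectors=hidden_states[:, i]; here we just echo the binding.
    return {"payload": payload, "interpretation_prompt": interpretation_prompt}


with gr.Blocks() as demo:
    interpretation_prompt = gr.Textbox(label="Interpretation prompt")
    json_output = gr.JSON()
    with gr.Row():
        for i, token in enumerate(["The", "cat", "sat"]):  # illustrative tokens
            btn = gr.Button(token)
            # partial freezes this button's payload into the handler, so only
            # the textbox value passes through Gradio's inputs list.
            btn.click(partial(run_interpretation, payload=f"hidden state column {i}"),
                      [interpretation_prompt], [json_output])

demo.launch()

Binding per-button data in a closure this way avoids threading it through a gr.State component, which is presumably why the diff passes interpreted_vectors via partial rather than as a Gradio input.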
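
The other mechanism spans the first and third hunks: tokens_container is a fixed pool of MAX_PROMPT_TOKENS invisible buttons created at build time, and get_hidden_states returns one button per slot, visible ones labeled with the prompt's tokens plus invisible placeholders as padding, which Gradio applies as in-place updates to the pre-rendered components. A self-contained sketch of just that pool-and-update pattern, again assuming Gradio 4.x component-instance updates (the constant's value and the widget labels are invented here):

import gradio as gr

MAX_PROMPT_TOKENS = 8  # illustrative; app.py defines the real constant elsewhere


def get_hidden_states(raw_original_prompt):
    tokens = raw_original_prompt.split()[:MAX_PROMPT_TOKENS]  # stand-in for the tokenizer
    # One return value per pre-rendered button: visible buttons carry the
    # tokens, invisible placeholders pad the list out to the pool size.
    btns = [gr.Button(tok, visible=True) for tok in tokens]
    btns += [gr.Button('', visible=False)] * (MAX_PROMPT_TOKENS - len(tokens))
    return btns


with gr.Blocks() as demo:
    original_prompt_raw = gr.Textbox(label="Original prompt")
    original_prompt_btn = gr.Button("Send")
    with gr.Row():
        tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])

demo.launch()

The fixed pool exists because event outputs must be declared before launch; MAX_PROMPT_TOKENS therefore caps how many token buttons the UI can ever show.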