Update app.py
app.py CHANGED
@@ -11,6 +11,7 @@ from interpret import InterpretationPrompt
 
 MAX_PROMPT_TOKENS = 60
 
+
 ## info
 dataset_info = [
     {'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
@@ -56,7 +57,7 @@ suggested_interpretation_prompts = ["Sure, I'll summarize your message:", "Sure,
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
+def get_hidden_states(raw_original_prompt):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
@@ -65,17 +66,13 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
     token_btns = ([gr.Button(token, visible=True) for token in tokens]
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
     progress_dummy_output = ''
-
-
-
-@spaces.GPU
-def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
-    return interpret_prompt.generate(*args, **kwargs)
+    invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_PROMPT_TOKENS)]
+    return [progress_dummy_output, hidden_states, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
 def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
-                       temperature, top_k, top_p, repetition_penalty, length_penalty,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
 
     interpreted_vectors = global_state[:, i]
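Note on the hunk above: Gradio cannot create or destroy components after the Blocks layout is built, so `get_hidden_states` returns one update per pre-allocated slot, visible buttons for the actual prompt tokens, invisible placeholders up to `MAX_PROMPT_TOKENS`, and freshly blanked interpretation textboxes. A minimal, self-contained sketch of the same fixed-pool pattern (names like `MAX_SLOTS` and `show_tokens` are illustrative, not from app.py):

```python
import gradio as gr

MAX_SLOTS = 8  # stand-in for MAX_PROMPT_TOKENS

def show_tokens(text):
    # Toy whitespace "tokenizer"; app.py uses the model's tokenizer instead.
    tokens = text.split()[:MAX_SLOTS]
    # Return one update per pre-allocated slot, in the same order as the
    # outputs list below: real tokens become visible buttons, the rest
    # stay hidden placeholders.
    return ([gr.Button(tok, visible=True) for tok in tokens]
            + [gr.Button('', visible=False) for _ in range(MAX_SLOTS - len(tokens))])

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    tokenize_btn = gr.Button('Tokenize')
    token_slots = [gr.Button('', visible=False) for _ in range(MAX_SLOTS)]
    tokenize_btn.click(show_tokens, [prompt], token_slots)

demo.launch()
```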
@@ -98,8 +95,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
     # generate the interpretations
-    generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
-    generated = generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs)
+    # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
+    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     progress_dummy_output = ''
     return ([progress_dummy_output] +
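The deleted `generate_interpretation_gpu` indirection (together with the `use_gpu` flag removed further down) chose between a ZeroGPU-decorated wrapper and a plain lambda; since `run_interpretation` is itself decorated with `@spaces.GPU`, the whole call already runs on a GPU worker, so the prompt object appears safe to call `generate` directly. For context, this is the usual shape of the ZeroGPU decorator from the `spaces` package (a sketch with placeholder names, not code from this repo):

```python
import spaces  # Hugging Face ZeroGPU helper, preinstalled on Spaces

@spaces.GPU  # a GPU is attached only for the duration of this call
def heavy_generate(model, model_inputs, **generation_kwargs):
    # Everything inside runs on the GPU worker; the rest of the app stays on CPU.
    return model.generate(**model_inputs, **generation_kwargs)
```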
@@ -187,14 +184,14 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     with gr.Group():
         original_prompt_raw.render()
-        original_prompt_btn = gr.Button('
+        original_prompt_btn = gr.Button('Output Token List', variant='primary')
 
     tokens_container = []
     with gr.Row():
+        gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')
         for i in range(MAX_PROMPT_TOKENS):
             btn = gr.Button('', visible=False, elem_classes=['token_btn'])
             tokens_container.append(btn)
-    use_gpu = False # gr.Checkbox(value=False, label='Use GPU')
     progress_dummy = gr.Markdown('', elem_id='progress_dummy')
 
 
@@ -226,12 +223,12 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     # event listeners
     for i, btn in enumerate(tokens_container):
-        btn.click(partial(run_interpretation, i=i
+        btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
                   num_tokens, do_sample, temperature,
                   top_k, top_p, repetition_penalty, length_penalty,
                   ], [progress_dummy, *interpretation_bubbles])
 
     original_prompt_btn.click(get_hidden_states,
                               [original_prompt_raw],
-                              [progress_dummy, global_state, *tokens_container])
+                              [progress_dummy, global_state, *tokens_container, *interpretation_bubbles])
 demo.launch()
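On the event wiring: `partial(run_interpretation, i=i)` freezes each button's token index at loop time, which is why `i` moved into `run_interpretation`'s signature; a plain closure over the loop variable would late-bind and leave every button pointing at the last token. Gradio fills the remaining parameters from the inputs list and spreads the return value over the outputs list in order. A standalone illustration of the binding difference (toy names, not from app.py):

```python
from functools import partial

def run(a, i=None):
    return f'a={a}, i={i}'

# Late binding: every lambda sees the final value of the loop variable.
handlers = [lambda a: run(a, i=i) for i in range(3)]
print(handlers[0]('x'))  # a=x, i=2  -- not 0!

# partial captures the value of i at creation time, one per button.
handlers = [partial(run, i=i) for i in range(3)]
print(handlers[0]('x'))  # a=x, i=0
```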