dar-tau committed
Commit bc2f9ff · Parent: ce4a0c3

Update app.py

Files changed (1): app.py +11 -14
app.py CHANGED
@@ -11,6 +11,7 @@ from interpret import InterpretationPrompt
 
 MAX_PROMPT_TOKENS = 60
 
+
 ## info
 dataset_info = [
     {'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
@@ -56,7 +57,7 @@ suggested_interpretation_prompts = ["Sure, I'll summarize your message:", "Sure,
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
+def get_hidden_states(raw_original_prompt):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
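For context on what `get_hidden_states` is assembling: the hunk below returns a `hidden_states` tensor that `run_interpretation` later slices per token as `global_state[:, i]`. The Space's exact model and stacking code are not visible in this diff, so the following is only a hedged sketch of how such a tensor is typically produced with `transformers` ('gpt2' is a stand-in model):

```python
# Hedged sketch: building a per-layer hidden-state tensor like `hidden_states`.
# 'gpt2' and the prompt are illustrative stand-ins, not the Space's real setup.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

model_inputs = tokenizer('Why is the sky blue?', return_tensors='pt')
with torch.no_grad():
    outputs = model(**model_inputs, output_hidden_states=True)

# tuple of (batch, seq, hidden) per layer -> (layers+1, seq, hidden) for batch 1
hidden_states = torch.stack(outputs.hidden_states).squeeze(1)
vectors_for_token_i = hidden_states[:, 3]  # token 3 at every layer, cf. global_state[:, i]
```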
@@ -65,17 +66,13 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
     token_btns = ([gr.Button(token, visible=True) for token in tokens]
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
     progress_dummy_output = ''
-    return [progress_dummy_output, hidden_states, *token_btns]
-
-
-@spaces.GPU
-def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
-    return interpret_prompt.generate(*args, **kwargs)
+    invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_PROMPT_TOKENS)]
+    return [progress_dummy_output, hidden_states, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
 def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
-                       temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
 
     interpreted_vectors = global_state[:, i]
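The `invisible_bubbles` addition above is the functional core of this commit: when a new prompt is tokenized, `get_hidden_states` now also emits invisible `gr.Textbox` updates, so interpretation bubbles left over from the previous prompt get hidden. A self-contained sketch of that reset pattern (the whitespace "tokenizer" and all component names here are illustrative, not the Space's code):

```python
# Hedged sketch: a click handler returns invisible component updates alongside
# its real outputs, hiding stale results from the previous run.
import gradio as gr

MAX_TOKENS = 8  # stand-in for MAX_PROMPT_TOKENS

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Prompt')
    compute_btn = gr.Button('Output Token List', variant='primary')
    token_btns = [gr.Button('', visible=False) for _ in range(MAX_TOKENS)]
    bubbles = [gr.Textbox('', visible=False) for _ in range(MAX_TOKENS)]

    def tokenize(raw_prompt):
        tokens = raw_prompt.split()[:MAX_TOKENS]  # toy stand-in tokenizer
        btns = ([gr.Button(t, visible=True) for t in tokens]
                + [gr.Button('', visible=False) for _ in range(MAX_TOKENS - len(tokens))])
        # hide every interpretation bubble from the previous prompt
        hidden = [gr.Textbox('', visible=False) for _ in range(MAX_TOKENS)]
        return [*btns, *hidden]

    compute_btn.click(tokenize, [prompt], [*token_btns, *bubbles])

demo.launch()
```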
@@ -98,8 +95,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
     # generate the interpretations
-    generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
-    generated = generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs)
+    # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
+    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     progress_dummy_output = ''
     return ([progress_dummy_output] +
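Because `run_interpretation` is itself decorated with `@spaces.GPU`, the old `use_gpu` flag and the separate `generate_interpretation_gpu` wrapper become redundant, which is what this hunk removes: on a ZeroGPU Space, the decorator requests a GPU for the duration of each decorated call. A minimal sketch of the pattern (placeholder body, not the Space's generation code):

```python
# Hedged sketch of the ZeroGPU pattern. The body is a placeholder; the real
# function calls interpretation_prompt.generate(...) as shown in the hunk above.
import spaces
import torch

@spaces.GPU  # on ZeroGPU Spaces, allocates a GPU per call; degrades to a pass-through locally
def run_interpretation(vectors: torch.Tensor) -> torch.Tensor:
    return vectors * 2  # placeholder for the actual GPU work
```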
@@ -187,14 +184,14 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     with gr.Group():
         original_prompt_raw.render()
-        original_prompt_btn = gr.Button('Compute', variant='primary')
+        original_prompt_btn = gr.Button('Output Token List', variant='primary')
 
     tokens_container = []
     with gr.Row():
+        gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')
         for i in range(MAX_PROMPT_TOKENS):
             btn = gr.Button('', visible=False, elem_classes=['token_btn'])
             tokens_container.append(btn)
-    use_gpu = False # gr.Checkbox(value=False, label='Use GPU')
     progress_dummy = gr.Markdown('', elem_id='progress_dummy')
 
 
@@ -226,12 +223,12 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     # event listeners
     for i, btn in enumerate(tokens_container):
-        btn.click(partial(run_interpretation, i=i, use_gpu=use_gpu), [global_state, interpretation_prompt,
+        btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
                   num_tokens, do_sample, temperature,
                   top_k, top_p, repetition_penalty, length_penalty,
                   ], [progress_dummy, *interpretation_bubbles])
 
     original_prompt_btn.click(get_hidden_states,
                               [original_prompt_raw],
-                              [progress_dummy, global_state, *tokens_container])
+                              [progress_dummy, global_state, *tokens_container, *interpretation_bubbles])
 demo.launch()
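One detail worth noting in the event wiring of the last hunk: `partial(run_interpretation, i=i)` freezes the token index at loop time. Binding through a plain closure instead would hit Python's late-binding pitfall, where every button handler sees the final value of `i`. A quick standalone illustration, unrelated to the Space's model code:

```python
from functools import partial

# partial binds the current value of i when each handler is created ...
handlers = [partial(print, i) for i in range(3)]
for h in handlers:
    h()            # prints 0, 1, 2

# ... whereas a bare closure looks i up at call time, after the loop has ended
late = [lambda: print(i) for i in range(3)]
for f in late:
    f()            # prints 2, 2, 2
```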