dar-tau commited on
Commit
3e684af
1 Parent(s): d75586b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -68,9 +68,10 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
68
  + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
69
  + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
70
  + [*extra_components])
71
-
 
72
  @spaces.GPU
73
- def get_hidden_states(raw_original_prompt):
74
  model, tokenizer = global_state.model, global_state.tokenizer
75
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
76
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
@@ -222,7 +223,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
222
  ], [progress_dummy, *interpretation_bubbles])
223
 
224
  original_prompt_btn.click(get_hidden_states,
225
- [original_prompt_raw],
226
  [progress_dummy, *tokens_container, *interpretation_bubbles])
227
  original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
228
 
 
68
  + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
69
  + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
70
  + [*extra_components])
71
+
72
+
73
  @spaces.GPU
74
+ def get_hidden_states(global_state, raw_original_prompt):
75
  model, tokenizer = global_state.model, global_state.tokenizer
76
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
77
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
 
223
  ], [progress_dummy, *interpretation_bubbles])
224
 
225
  original_prompt_btn.click(get_hidden_states,
226
+ [gr.State(global_state), original_prompt_raw],
227
  [progress_dummy, *tokens_container, *interpretation_bubbles])
228
  original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
229