Update app.py
Browse files
app.py
CHANGED
@@ -68,9 +68,10 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
|
|
68 |
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
|
69 |
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
|
70 |
+ [*extra_components])
|
71 |
-
|
|
|
72 |
@spaces.GPU
|
73 |
-
def get_hidden_states(raw_original_prompt):
|
74 |
model, tokenizer = global_state.model, global_state.tokenizer
|
75 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
76 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
@@ -222,7 +223,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
222 |
], [progress_dummy, *interpretation_bubbles])
|
223 |
|
224 |
original_prompt_btn.click(get_hidden_states,
|
225 |
-
[original_prompt_raw],
|
226 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
227 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
228 |
|
|
|
68 |
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
|
69 |
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
|
70 |
+ [*extra_components])
|
71 |
+
|
72 |
+
|
73 |
@spaces.GPU
|
74 |
+
def get_hidden_states(global_state, raw_original_prompt):
|
75 |
model, tokenizer = global_state.model, global_state.tokenizer
|
76 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
77 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
|
223 |
], [progress_dummy, *interpretation_bubbles])
|
224 |
|
225 |
original_prompt_btn.click(get_hidden_states,
|
226 |
+
[gr.State(global_state), original_prompt_raw],
|
227 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
228 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
229 |
|