Update app.py
Browse files
app.py
CHANGED
@@ -74,7 +74,7 @@ def initialize_gpu():
|
|
74 |
pass
|
75 |
|
76 |
|
77 |
-
def reset_model(model_name,
|
78 |
# extract model info
|
79 |
model_args = deepcopy(model_info[model_name])
|
80 |
model_path = model_args.pop('model_path')
|
@@ -90,10 +90,11 @@ def reset_model(model_name, global_state):
|
|
90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
92 |
gc.collect()
|
93 |
-
|
|
|
94 |
|
95 |
|
96 |
-
def get_hidden_states(
|
97 |
model, tokenizer = global_state.model, global_state.tokenizer
|
98 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
99 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
@@ -105,11 +106,11 @@ def get_hidden_states(global_state, raw_original_prompt):
|
|
105 |
progress_dummy_output = ''
|
106 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
|
107 |
global_state.hidden_states = hidden_states
|
108 |
-
return [progress_dummy_output,
|
109 |
|
110 |
|
111 |
@spaces.GPU
|
112 |
-
def run_interpretation(
|
113 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
114 |
num_beams=1):
|
115 |
|
@@ -143,7 +144,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
143 |
|
144 |
|
145 |
## main
|
146 |
-
|
147 |
torch.set_grad_enabled(False)
|
148 |
model_name = 'LLAMA2-7B'
|
149 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
@@ -153,7 +154,6 @@ for i in range(MAX_PROMPT_TOKENS):
|
|
153 |
tokens_container.append(btn)
|
154 |
|
155 |
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
156 |
-
global_state = gr.State(reset_model(model_name, GlobalState()))
|
157 |
|
158 |
with gr.Row():
|
159 |
with gr.Column(scale=5):
|
@@ -236,8 +236,9 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
236 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
237 |
) for i in range(model.config.num_hidden_layers)]
|
238 |
|
|
|
239 |
# event listeners
|
240 |
-
model_chooser.change(reset_new_model, [model_chooser
|
241 |
|
242 |
for i, btn in enumerate(tokens_container):
|
243 |
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
@@ -247,6 +248,6 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
247 |
|
248 |
original_prompt_btn.click(get_hidden_states,
|
249 |
[original_prompt_raw],
|
250 |
-
[progress_dummy,
|
251 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
252 |
demo.launch()
|
|
|
74 |
pass
|
75 |
|
76 |
|
77 |
+
def reset_model(model_name, return_state=False):
|
78 |
# extract model info
|
79 |
model_args = deepcopy(model_info[model_name])
|
80 |
model_path = model_args.pop('model_path')
|
|
|
90 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
91 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
92 |
gc.collect()
|
93 |
+
if return_state:
|
94 |
+
return global_state
|
95 |
|
96 |
|
97 |
+
def get_hidden_states(raw_original_prompt):
|
98 |
model, tokenizer = global_state.model, global_state.tokenizer
|
99 |
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
|
100 |
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
|
|
|
106 |
progress_dummy_output = ''
|
107 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
|
108 |
global_state.hidden_states = hidden_states
|
109 |
+
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
110 |
|
111 |
|
112 |
@spaces.GPU
|
113 |
+
def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
114 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
115 |
num_beams=1):
|
116 |
|
|
|
144 |
|
145 |
|
146 |
## main
|
147 |
+
global_state = reset_model(model_name, return_state=True)
|
148 |
torch.set_grad_enabled(False)
|
149 |
model_name = 'LLAMA2-7B'
|
150 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
|
|
154 |
tokens_container.append(btn)
|
155 |
|
156 |
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
|
157 |
|
158 |
with gr.Row():
|
159 |
with gr.Column(scale=5):
|
|
|
236 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
237 |
) for i in range(model.config.num_hidden_layers)]
|
238 |
|
239 |
+
|
240 |
# event listeners
|
241 |
+
model_chooser.change(reset_new_model, [model_chooser], [])
|
242 |
|
243 |
for i, btn in enumerate(tokens_container):
|
244 |
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
|
|
248 |
|
249 |
original_prompt_btn.click(get_hidden_states,
|
250 |
[original_prompt_raw],
|
251 |
+
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
252 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
253 |
demo.launch()
|