dar-tau committed
Commit: e1cea83
Parent: 9fa8328

Update app.py

Files changed (1): app.py (+10 -9)
app.py CHANGED
@@ -74,7 +74,7 @@ def initialize_gpu():
     pass
 
 
-def reset_model(model_name, global_state):
+def reset_model(model_name, return_state=False):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -90,10 +90,11 @@ def reset_model(model_name, global_state):
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
-    return global_state
+    if return_state:
+        return global_state
 
 
-def get_hidden_states(global_state, raw_original_prompt):
+def get_hidden_states(raw_original_prompt):
     model, tokenizer = global_state.model, global_state.tokenizer
     original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
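
This hunk is the heart of the commit: model state moves out of per-callback arguments and into one module-level object that `reset_model` mutates in place, with an optional `return_state` flag for the initial assignment. A minimal sketch of that pattern, assuming stand-ins (`fake_load`, the `SimpleNamespace` fields) rather than the app's real loaders:

```python
from types import SimpleNamespace

# One shared object at module level; callbacks read and mutate it in place.
global_state = SimpleNamespace(model=None, tokenizer=None, hidden_states=None)

def fake_load(model_name):
    # Stand-in for the AutoModelClass/AutoTokenizer from_pretrained calls.
    return f"<model:{model_name}>", f"<tokenizer:{model_name}>"

def reset_model(model_name, return_state=False):
    global_state.model, global_state.tokenizer = fake_load(model_name)
    if return_state:
        # Allows `global_state = reset_model(..., return_state=True)` at import
        # time, while event handlers can call it purely for its side effects.
        return global_state

state = reset_model('LLAMA2-7B', return_state=True)
assert state is global_state  # same object whether returned or not
```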
@@ -105,11 +106,11 @@ def get_hidden_states(global_state, raw_original_prompt):
     progress_dummy_output = ''
     invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
     global_state.hidden_states = hidden_states
-    return [progress_dummy_output, global_state, *token_btns, *invisible_bubbles]
+    return [progress_dummy_output, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
-def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
+def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
 
@@ -143,7 +144,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
 
 
 ## main
-
+global_state = reset_model(model_name, return_state=True)
 torch.set_grad_enabled(False)
 model_name = 'LLAMA2-7B'
 original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
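
One ordering caveat in this hunk: as committed, the added line calls `reset_model(model_name, return_state=True)` one line before the visible `model_name = 'LLAMA2-7B'` assignment, which raises a NameError unless `model_name` is already bound earlier in app.py (not shown in this diff). A hedged sketch of the safer ordering, with a stub `reset_model`:

```python
def reset_model(model_name, return_state=False):
    # Stub standing in for the real loader; app.py loads model/tokenizer here.
    state = {'model_name': model_name}
    if return_state:
        return state

model_name = 'LLAMA2-7B'                                   # bind the name first...
global_state = reset_model(model_name, return_state=True)  # ...then use it
print(global_state)  # {'model_name': 'LLAMA2-7B'}
```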
@@ -153,7 +154,6 @@ for i in range(MAX_PROMPT_TOKENS):
     tokens_container.append(btn)
 
 with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
-    global_state = gr.State(reset_model(model_name, GlobalState()))
 
     with gr.Row():
         with gr.Column(scale=5):
@@ -236,8 +236,9 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
                                elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
                                ) for i in range(model.config.num_hidden_layers)]
 
+
     # event listeners
-    model_chooser.change(reset_new_model, [model_chooser, global_state], [global_state])
+    model_chooser.change(reset_new_model, [model_chooser], [])
 
     for i, btn in enumerate(tokens_container):
         btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
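
With state held at module level, the dropdown listener no longer routes `global_state` through Gradio at all: the handler takes only the dropdown's value and returns nothing, so the outputs list is empty. A minimal sketch of that listener shape; `reset_new_model` below is an illustrative stand-in for the app's handler of the same name, whose body is not part of this diff:

```python
import gradio as gr

current = {'model_name': None}  # stand-in for the app's module-level state

def reset_new_model(model_name):
    # Mutate shared state; return nothing, so no outputs are needed.
    current['model_name'] = model_name

with gr.Blocks() as demo:
    model_chooser = gr.Dropdown(choices=['LLAMA2-7B'], value='LLAMA2-7B', label='Model')
    # Inputs shrink to just the dropdown and outputs go empty: the handler
    # now works entirely by side effect.
    model_chooser.change(reset_new_model, [model_chooser], [])
```

One trade-off worth noting: unlike `gr.State`, which is per-session, module-level state is shared by every visitor to the Space, so concurrent users can overwrite each other's model choice.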
@@ -247,6 +248,6 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
     original_prompt_btn.click(get_hidden_states,
                               [original_prompt_raw],
-                              [progress_dummy, global_state, *tokens_container, *interpretation_bubbles])
+                              [progress_dummy, *tokens_container, *interpretation_bubbles])
     original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
     demo.launch()
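
The same positional contract explains this last hunk and the earlier `get_hidden_states` one: Gradio matches a callback's return values one-to-one against its `outputs` list, so once `global_state` stops being a `gr.State` output, it must leave the returned list as well. A minimal, hypothetical sketch of that contract; `show_tokens` and the three-button layout are illustrative, not the app's code:

```python
import gradio as gr

MAX_BTNS = 3  # stand-in for MAX_PROMPT_TOKENS

def show_tokens(prompt):
    # One returned value per output component, in order: first the progress
    # textbox, then one update per button. No state object is returned;
    # shared state would be mutated at module level instead.
    tokens = prompt.split()[:MAX_BTNS]
    btn_updates = [gr.Button(tok, visible=True) for tok in tokens]
    btn_updates += [gr.Button(visible=False)] * (MAX_BTNS - len(btn_updates))
    return [f"{len(tokens)} tokens", *btn_updates]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label='Original Prompt')
    go = gr.Button('Send')
    progress = gr.Textbox(label='Progress')
    btns = [gr.Button(visible=False) for _ in range(MAX_BTNS)]
    go.click(show_tokens, [prompt], [progress, *btns])

if __name__ == '__main__':
    demo.launch()
```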
 