dar-tau committed
Commit 96ac9fb
Parent: 5a9ec75

Update app.py

Files changed (1): app.py (+6 -5)
app.py CHANGED
@@ -75,7 +75,7 @@ def initialize_gpu():
     pass
 
 
-def reset_model(model_name, *extra_components):
+def reset_model(model_name):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -113,7 +113,7 @@ def get_hidden_states(raw_original_prompt):
 def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                        num_beams=1):
-
+    print(f'run {global_state.model}')
     interpreted_vectors = global_state.hidden_states[:, i]
     length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
@@ -146,6 +146,7 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
 ## main
 torch.set_grad_enabled(False)
 global_state = GlobalState()
+extra_components = []
 
 model_name = 'LLAMA2-7B'
 reset_model(model_name)
@@ -183,7 +184,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
     # gr.Markdown('<span style="font-size:180px;">🤔</span>')
 
     with gr.Group():
-        model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)
+        model_chooser = gr.Radio(label='Model', choices=list(model_info.keys()), value=model_name)
 
     with gr.Blocks() as demo_blocks:
         gr.Markdown('## Choose Your Interpretation Prompt')
@@ -233,9 +234,9 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
 
     # event listeners
-    all_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
+    extra_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
                         original_prompt_raw]
-    model_chooser.change(reset_model, [model_chooser, *all_components], all_components)
+    model_chooser.change(reset_model, [model_chooser], extra_components)
 
     for i, btn in enumerate(tokens_container):
         btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
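The rewiring above follows Gradio's standard event pattern: model_chooser.change(fn, inputs, outputs) passes only the inputs' values to the callback, and the callback's return values are applied to the components listed in outputs, so reset_model no longer needs to receive the components via *extra_components. A minimal standalone sketch of that pattern, assuming reset_model returns one update per output component (status and run_btn here are illustrative stand-ins, not components from app.py):

import gradio as gr

MODEL_CHOICES = ['LLAMA2-7B', 'MISTRAL-7B']  # hypothetical model list

def reset_model(model_name):
    # ... reload weights for `model_name` here ...
    # With outputs=[status, run_btn], the callback returns one update per
    # output component rather than receiving the components as inputs.
    return gr.update(value=f'loaded {model_name}'), gr.update(interactive=True)

with gr.Blocks() as demo:
    model_chooser = gr.Radio(label='Model', choices=MODEL_CHOICES, value=MODEL_CHOICES[0])
    status = gr.Textbox(label='Status')
    run_btn = gr.Button('Run')
    # only the radio's value goes in; the two updates come back out
    model_chooser.change(reset_model, [model_chooser], [status, run_btn])

demo.launch()

Returning updates keeps the callback's signature stable no matter how many UI components need refreshing after a model switch.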
 