dar-tau committed
Commit 1fac350
Parent: f1096d3

Update app.py

Files changed (1)
  1. app.py +8 -11
app.py CHANGED
@@ -32,10 +32,10 @@ model_info = {
                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                      ), # , load_in_8bit=True
 
-    'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
-                     original_prompt_template='<bos>{prompt}',
-                     interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
-                     ),
+    # 'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
+    #                  original_prompt_template='<bos>{prompt}',
+    #                  interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
+    #                  ),
 
     'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
                                 original_prompt_template='<s>{prompt}',
@@ -75,7 +75,7 @@ def initialize_gpu():
     pass
 
 
-def reset_model(model_name, return_extra_components=True):
+def reset_model(model_name):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -91,10 +91,7 @@ def reset_model(model_name, return_extra_components=True):
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
-    if return_extra_components:
-        extra_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
-                            original_prompt_raw]
-        return extra_components
+    return extra_components
 
 
 def get_hidden_states(raw_original_prompt):
@@ -151,7 +148,7 @@ torch.set_grad_enabled(False)
 global_state = GlobalState()
 
 model_name = 'LLAMA2-7B'
-reset_model(model_name, return_extra_components=False)
+reset_model(model_name)
 original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
 
@@ -238,7 +235,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
     # event listeners
     extra_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
                         original_prompt_raw]
-    model_chooser.change(reset_model, [model_chooser], extra_components)
+    model_chooser.change(reset_model, [model_chooser, extra_components], extra_components)
 
     for i, btn in enumerate(tokens_container):
         btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
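Note on the prompt templates in the first hunk: they mix two kinds of placeholders. A hedged reading (the substitution code is not part of this diff): '{prompt}' is an ordinary str.format slot, while '[X]' is a literal marker the app presumably replaces itself later, e.g. at the point where an interpreted hidden state is injected:

# Hedged illustration of how the templates appear to be consumed; only the
# '{prompt}' slot is standard str.format, '[X]' survives as a literal marker
# for the app to handle separately (an assumption, not shown in this diff).
template = '<s>[INST] [X] [/INST] {prompt}'
filled = template.format(prompt='Sure, here is how the model sees that token:')
print(filled)
# -> <s>[INST] [X] [/INST] Sure, here is how the model sees that token: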
 
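The middle hunks keep the model-reload pattern intact while simplifying its signature. A minimal sketch of that pattern, assuming a CUDA machine; AutoModelForCausalLM stands in for the app's AutoModelClass variable, and torch.cuda.empty_cache() is an addition of mine, not in the diff:

# Hedged sketch of the reload pattern inside reset_model.
import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def reload_model(global_state, model_path, **model_args):
    global_state.model = AutoModelForCausalLM.from_pretrained(model_path, **model_args).cuda()
    global_state.tokenizer = AutoTokenizer.from_pretrained(model_path)
    gc.collect()               # drop Python references to the previous model
    torch.cuda.empty_cache()   # then release cached GPU blocks (my assumption)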
 
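The signature change and the new unconditional `return extra_components` both serve the event wiring in the last hunk. A minimal, self-contained sketch of Gradio's .change(fn, inputs, outputs) pattern, with stand-in components (the `status` textbox and the stub body are illustrative, not the app's code):

import gradio as gr

def reset_model(model_name):
    # the real reset_model reloads global_state.model/.tokenizer here and
    # returns UI components so Gradio redraws them; this stub just returns
    # fresh values, one per output component
    return f'loaded {model_name}', ''

with gr.Blocks() as demo:
    model_chooser = gr.Dropdown(['LLAMA2-7B', 'Mistral-7B Instruct'], label='Model')
    status = gr.Textbox(label='Status')
    original_prompt_raw = gr.Textbox(label='Original Prompt')
    # Gradio passes the current values of `inputs` to fn and writes its
    # return values back into `outputs`
    model_chooser.change(reset_model, [model_chooser], [status, original_prompt_raw])

demo.launch()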