dar-tau committed
Commit
049eed9
1 Parent(s): 03ac798

Update app.py

Files changed (1): app.py (+4 -4)
app.py CHANGED
@@ -74,7 +74,7 @@ def initialize_gpu():
     pass
 
 
-def reset_model(model_name, return_state=False):
+def reset_model(model_name):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -90,8 +90,6 @@ def reset_model(model_name, return_state=False):
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
-    if return_state:
-        return global_state
 
 
 def get_hidden_states(raw_original_prompt):
@@ -145,11 +143,13 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
 
 ## main
 torch.set_grad_enabled(False)
+global_state = GlobalState()
 
 model_name = 'LLAMA2-7B'
-global_state = reset_model(model_name, return_state=True)
+reset_model(model_name)
 original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
+
 for i in range(MAX_PROMPT_TOKENS):
     btn = gr.Button('', visible=False, elem_classes=['token_btn'])
     tokens_container.append(btn)
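
The commit drops the return_state plumbing: global_state is now created once at module level as a GlobalState() instance, and reset_model() fills it in place instead of returning it. Below is a minimal, self-contained sketch of that pattern, under stated assumptions: the GlobalState fields and the model_info registry are hypothetical (neither is defined in this diff), and AutoModelForCausalLM stands in for the AutoModelClass resolved elsewhere in app.py.

# Sketch only: GlobalState fields, model_info contents, and the model class
# are assumptions for illustration; they are not shown in this diff.
import gc
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


@dataclass
class GlobalState:
    # hypothetical fields; the real GlobalState in app.py may hold more
    model: Any = None
    tokenizer: Any = None


# hypothetical registry mapping a display name to from_pretrained arguments
model_info = {
    'LLAMA2-7B': {
        'model_path': 'meta-llama/Llama-2-7b-chat-hf',
        'tokenizer_path': 'meta-llama/Llama-2-7b-chat-hf',
        'torch_dtype': torch.float16,
    },
}

# single shared state object, created once at import time
global_state = GlobalState()


def reset_model(model_name):
    # extract model info, as in the diff above
    model_args = deepcopy(model_info[model_name])
    model_path = model_args.pop('model_path')
    tokenizer_path = model_args.pop('tokenizer_path')

    # load onto the shared state instead of returning it
    global_state.model = AutoModelForCausalLM.from_pretrained(model_path, **model_args).cuda()
    global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
    gc.collect()


model_name = 'LLAMA2-7B'
reset_model(model_name)

With this shape, any Gradio callbacks that close over global_state keep seeing the freshly loaded model after a reload, which appears to be the point of removing the return_state flag.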