dar-tau committed on
Commit
cee7c56
1 Parent(s): 48e731a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausa
13
  from interpret import InterpretationPrompt
14
 
15
  MAX_PROMPT_TOKENS = 60
 
16
 
17
 
18
  ## info
@@ -102,7 +103,7 @@ def get_hidden_states(raw_original_prompt):
102
  token_btns = ([gr.Button(token, visible=True) for token in tokens]
103
  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
104
  progress_dummy_output = ''
105
- invisible_bubbles = [gr.Textbox('', visible=False) for i in range(len(interpretation_bubbles))]
106
  global_state.hidden_states = hidden_states
107
  return [progress_dummy_output, *token_btns, *invisible_bubbles]
108
 
@@ -136,9 +137,9 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
136
  generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
137
  generation_texts = tokenizer.batch_decode(generated)
138
  progress_dummy_output = ''
139
- return ([progress_dummy_output] +
140
- [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {i}') for text in generation_texts]
141
- )
142
 
143
 
144
  ## main
@@ -235,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
235
  progress_dummy = gr.Markdown('', elem_id='progress_dummy')
236
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
237
  elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
238
- ) for i in range(model.config.num_hidden_layers)]
239
 
240
 
241
  # event listeners
 
13
  from interpret import InterpretationPrompt
14
 
15
  MAX_PROMPT_TOKENS = 60
16
+ MAX_NUM_LAYERS = 50
17
 
18
 
19
  ## info
 
103
  token_btns = ([gr.Button(token, visible=True) for token in tokens]
104
  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
105
  progress_dummy_output = ''
106
+ invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
107
  global_state.hidden_states = hidden_states
108
  return [progress_dummy_output, *token_btns, *invisible_bubbles]
109
 
 
137
  generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
138
  generation_texts = tokenizer.batch_decode(generated)
139
  progress_dummy_output = ''
140
+ bubble_outputs = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {i}') for text in generation_texts]
141
+ bubble_outputs += [gr.Textbox(visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
142
+ return [progress_dummy_output, *bubble_outputs]
143
 
144
 
145
  ## main
 
236
  progress_dummy = gr.Markdown('', elem_id='progress_dummy')
237
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
238
  elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
239
+ ) for i in range(MAX_NUM_LAYERS)]
240
 
241
 
242
  # event listeners