dar-tau committed
Commit b30a06e
1 Parent(s): 9d7840a

Update app.py

Files changed (1): app.py +24 -25
app.py CHANGED
@@ -45,25 +45,22 @@ suggested_interpretation_prompts = ["Before responding, let me repeat the messag
 def initialize_gpu():
     pass
 
-def get_hidden_states(raw_original_prompt, interpretation_args, interpretation_outputs):
+def get_hidden_states(raw_original_prompt):
     original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-    token_btns = []
-    for i, token in enumerate(tokens):
-        btn = gr.Button(token)
-        btn.click(partial(run_interpretation, interpreted_vectors=hidden_states[:, i]),
-                  interpretation_args, interpretation_outputs)
-        token_btns.append(btn)
-    token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]
-    return token_btns
-
-
-def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
-                       temperature, top_k, top_p, repetition_penalty, length_penalty, interpreted_vectors, num_beams=1):
-
+    token_btns = ([gr.Button(token, visible=True) for token in tokens]
+                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
+    return [hidden_state, *token_btns]
+
+
+def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
+                       num_beams=1):
+
+    interpreted_vectors = global_state[:, i]
     length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
     # generation parameters
@@ -83,7 +80,7 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
     # generate the interpretations
-    generated = interpretation_prompt.generate(model, {0: hidden_states[:, -1]}, k=3, **generation_kwargs)
+    generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     return generation_texts
 
@@ -105,6 +102,8 @@ model = AutoModelClass.from_pretrained(model_name, **model_args)
 tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
 
 # demo
+global_state = gr.State([])
+json_output = gr.JSON()
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     with gr.Row():
         with gr.Column(scale=5):
@@ -144,15 +143,15 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
             interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
 
             with gr.Group('Output'):
+                tokens_container = []
                 with gr.Row():
-                    tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
-                    json_output = gr.JSON()
-
-    interpretation_args = [interpretation_prompt, num_tokens, do_sample, temperature,
-                           top_k, top_p, repetition_penalty, length_penalty]
-    interpretation_outputs = [json_output]
-
-    original_prompt_btn.click(partial(get_hidden_states, interpretation_args=interpretation_args,
-                                      interpretation_outputs=interpretation_outputs
-                                      ), [original_prompt_raw], [*tokens_container])
+                    for _ in range(MAX_PROMPT_TOKENS):
+                        btn = gr.Button('', visible=False)
+                        btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
+                                                                     top_k, top_p, repetition_penalty, length_penalty
+                                                                     ], [json_output])
+                        tokens_container.append(btn)
+                json_output.render()
+
+    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [global_state, *tokens_container])
 demo.launch()
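
For context on the Gradio pattern the updated code relies on (a shared gr.State holding the hidden states, a fixed pool of pre-created hidden buttons, and functools.partial baking a token index into each button's click handler), here is a minimal, self-contained sketch. It is an illustration only, not the Space's code: MAX_TOKENS, the stand-in handlers, and the placeholder "hidden states" are hypothetical names.

# Minimal sketch of the State + pre-created-button pattern (illustrative names only).
from functools import partial

import gradio as gr

MAX_TOKENS = 8  # stand-in for MAX_PROMPT_TOKENS

def get_hidden_states(prompt):
    # Stand-in for the real forward pass: the "hidden states" here are just the words.
    tokens = prompt.split()[:MAX_TOKENS]
    buttons = ([gr.Button(tok, visible=True) for tok in tokens]
               + [gr.Button('', visible=False) for _ in range(MAX_TOKENS - len(tokens))])
    return [tokens, *buttons]  # first output fills the State, the rest update the buttons

def run_interpretation(state, i):
    # Stand-in for the interpretation step: echo the token selected by its baked-in index.
    return {"token_index": i, "token": state[i]}

with gr.Blocks() as demo:
    state = gr.State([])
    prompt_box = gr.Textbox(label="Prompt")
    submit_btn = gr.Button("Submit")
    with gr.Row():
        token_buttons = [gr.Button('', visible=False) for _ in range(MAX_TOKENS)]
    json_out = gr.JSON()
    # Bind a fixed index to each button at build time; the handler then only needs the State.
    for i, btn in enumerate(token_buttons):
        btn.click(partial(run_interpretation, i=i), [state], [json_out])
    submit_btn.click(get_hidden_states, [prompt_box], [state, *token_buttons])

demo.launch()

The button instances returned from get_hidden_states overwrite the label and visibility of the pre-created buttons, which is why a fixed pool of hidden buttons is declared up front rather than created per prompt.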