dar-tau committed
Commit 45d9aa5
Parent: 7b7c573

Update app.py

Files changed (1): app.py (+16, -24)
app.py CHANGED
@@ -50,14 +50,19 @@ def get_hidden_states(raw_original_prompt):
     tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
     outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-    # with gr.Row() as tokens_container:
-    #     for token in tokens:
-    #         gr.Button(token)
-    return [gr.Button(tokens[i], visible=True) if i < len(tokens) else gr.Button('', visible=False) for i in range(MAX_PROMPT_TOKENS)]
-
-
-def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
-              temperature, top_k, top_p, repetition_penalty, length_penalty, num_beams=1):
+    token_btns = []
+    for i, token in enumerate(tokens):
+        btn = gr.Button(token)
+        btn.click(partial(run_interpretation, interpreted_vectors=hidden_states[:, i]),
+                  [interpretation_prompt, num_tokens, do_sample, temperature, top_k, top_p, repetition_penalty, length_penalty],
+                  [json_output])
+        token_btns.append(btn)
+    token_btns += [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]
+    return token_btns
+
+
+def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
+                       temperature, top_k, top_p, repetition_penalty, length_penalty, interpreted_vectors, num_beams=1):
 
     length_penalty = -length_penalty  # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
 
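The reworked get_hidden_states wires a click handler onto each token button by pre-binding that token's hidden-state column with functools.partial, so only the UI controls travel through Gradio's inputs list. Below is a minimal sketch of the pattern, assuming `from functools import partial` appears in app.py's import block (the hunk does not show the imports) and substituting a toy handler for run_interpretation:

from functools import partial

import gradio as gr
import torch

MAX_PROMPT_TOKENS = 8
# toy stand-in for the real tensor: (num_layers + 1, seq_len, hidden_dim)
hidden_states = torch.randn(13, 3, 16)

def interpret(prompt, interpreted_vectors):
    # interpreted_vectors arrives pre-bound by partial; only prompt
    # flows through the Gradio inputs list
    return {"prompt": prompt, "vector_norm": interpreted_vectors.norm().item()}

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Interpretation prompt")
    json_out = gr.JSON()
    with gr.Row():
        for i, token in enumerate(["The", "cat", "sat"]):
            btn = gr.Button(token)
            # bind this token's hidden-state column at definition time
            btn.click(partial(interpret, interpreted_vectors=hidden_states[:, i]),
                      [prompt_box], [json_out])

demo.launch()

Binding with partial at loop time also sidesteps Python's late-binding closures: a bare lambda referencing hidden_states[:, i] would evaluate i only when clicked, handing every button the last token's vector.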
@@ -77,16 +82,9 @@ def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do
     interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
-    # compute the hidden stated from the original prompt (after putting it in the right template)
-    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
-    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
-    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
-    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
-
     # generate the interpretations
     generated = interpretation_prompt.generate(model, {0: hidden_states[:, -1]}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
-    # tokens = [x.lstrip('▁') for x in tokenizer.tokenize(text)]
     return generation_texts
 
 
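This hunk removes the duplicated forward pass: the hidden states are computed once in get_hidden_states and the relevant column is handed to run_interpretation as interpreted_vectors. The tensor layout is what makes hidden_states[:, i] and hidden_states[:, -1] meaningful: outputs.hidden_states is a tuple of num_layers + 1 tensors of shape (batch, seq_len, hidden_dim), so stacking the squeezed entries gives (num_layers + 1, seq_len, hidden_dim), and indexing the second axis selects one token position across every layer. A sketch against a small stand-in checkpoint (gpt2 here is an assumption; the Space's real model is configured elsewhere in app.py):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

model_inputs = tokenizer("The cat sat", return_tensors="pt")
outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)

# tuple of num_layers + 1 tensors (embeddings plus every block), each (1, seq_len, 768)
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
print(hidden_states.shape)         # (13, seq_len, 768) for gpt2
print(hidden_states[:, -1].shape)  # last token across all layers: (13, 768)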
@@ -148,13 +146,7 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     with gr.Group('Output'):
         with gr.Row():
             tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
-        with gr.Column() as interpretations_container:
-            pass
-
-    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])
-    # btn.click(run_model,
-    #           [text, interpretation_prompt, num_tokens, do_sample, temperature,
-    #            top_k, top_p, repetition_penalty, length_penalty],
-    #           [tokens_container])
-
+        json_output = gr.JSON()
+
+    original_prompt_btn.click(get_hidden_states, [original_prompt_raw], [*tokens_container])
 demo.launch()
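The layout keeps a fixed pool of MAX_PROMPT_TOKENS invisible buttons because a handler cannot add components to a running interface; get_hidden_states instead returns one update per pre-allocated button, revealing as many as there are tokens, and the new gr.JSON() gives the per-token click handlers a stable output target in place of the empty gr.Column. A minimal sketch of the pattern, assuming a Gradio version (4.x) in which returning a component instance from a handler updates the matching output:

import gradio as gr

MAX_PROMPT_TOKENS = 8

def show_tokens(prompt):
    tokens = prompt.split()[:MAX_PROMPT_TOKENS]  # stand-in for the real tokenizer
    # one return value per pre-allocated button: visible for real tokens, hidden otherwise
    return [gr.Button(tok, visible=True) for tok in tokens] + \
           [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))]

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Original prompt")
    tokenize_btn = gr.Button("Tokenize")
    with gr.Row():
        tokens_container = [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS)]
    tokenize_btn.click(show_tokens, [prompt_box], [*tokens_container])

demo.launch()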