Update app.py
app.py (CHANGED)
@@ -23,7 +23,8 @@ class GlobalState:
 model : Optional[PreTrainedModel] = None
 hidden_states : Optional[torch.Tensor] = None
 interpretation_prompt_template : str = '{prompt}'
-original_prompt_template : str = '{prompt}'
+original_prompt_template : str = 'User: [X]\n\nAnswer: {prompt}'
+layers_format : str = 'model.layers.{k}'
 
 
 suggested_interpretation_prompts = [
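The new `layers_format` field lets each model entry name where its decoder blocks live as a dotted module path (Llama-style models use `model.layers.{k}`, GPT-2 uses `transformer.h.{k}`). A minimal sketch of how such a pattern can be resolved to a concrete layer module; the `get_layer` helper below is illustrative and not part of this app:

```python
# Illustrative only: resolve a layers_format pattern to the k-th decoder block.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('gpt2')

def get_layer(model, layers_format: str, k: int):
    # torch.nn.Module.get_submodule walks dotted paths such as 'transformer.h.3'
    return model.get_submodule(layers_format.format(k=k))

layer_3 = get_layer(model, 'transformer.h.{k}', 3)  # GPT-2; a Llama-style model would use 'model.layers.{k}'
```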
@@ -46,6 +47,7 @@ def reset_model(model_name, *extra_components):
 model_path = model_args.pop('model_path')
 global_state.original_prompt_template = model_args.pop('original_prompt_template')
 global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
+global_state.layers_format = model_args.pop('layers_format')
 tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
 use_ctransformers = model_args.pop('ctransformers', False)
 AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
@@ -96,7 +98,7 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
 
 # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
 interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
-interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
+interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt, layers_format=global_state.layers_format)
 
 # generate the interpretations
 # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
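For context on the `.format(...)` call above: templates are filled with plain `str.format`, which silently ignores unused keyword arguments, so the default `'{prompt}'` template simply drops `repeat=5`, while a template that does reference `{repeat}` picks it up. A tiny sketch (the second template is a made-up example, not one of the app's templates):

```python
# str.format ignores keyword arguments the template does not reference.
plain = '{prompt}'
with_repeat = 'Repeat the message {repeat} times: {prompt}'  # hypothetical template

raw = 'Sure, I will summarize this message'
print(plain.format(prompt=raw, repeat=5))        # -> 'Sure, I will summarize this message'
print(with_repeat.format(prompt=raw, repeat=5))  # -> 'Repeat the message 5 times: Sure, I will summarize this message'
```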
@@ -138,23 +140,24 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
 gr.Markdown(
 '''
-
-
-
-
+**👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
+In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
+So we can inject a representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
+we expect to get back a summary of the information that exists inside the hidden state, despite being from a different layer and a different run!! How cool is that! 💯💯💯
 ''', line_breaks=True)
 
 # with gr.Column(scale=1):
 # gr.Markdown('<span style="font-size:180px;">π€</span>')
 
 with gr.Group():
-model_chooser = gr.Radio(label='Model', choices=list(model_info.keys()), value=model_name)
+model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
 
 with gr.Blocks() as demo_blocks:
 gr.Markdown('## Choose Your Interpretation Prompt')
 with gr.Group('Interpretation'):
 interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
-gr.Examples([[p] for p in suggested_interpretation_prompts],
+interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
+[interpretation_prompt], cache_examples=False)
 
 
 gr.Markdown('## The Prompt to Analyze')
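The paragraph added above describes what is essentially cross-prompt activation patching. A rough, self-contained sketch of that mechanism, using GPT-2 as a small stand-in and a plain forward hook rather than the app's `InterpretationPrompt` class (the layer indices, prompts, and patched position are arbitrary choices for illustration):

```python
# Rough sketch of hidden-state injection, not this app's implementation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'gpt2'  # small stand-in; the app targets larger chat models
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

# 1. Capture the hidden state we want to interpret (layer 6, last token of a source prompt).
source = tokenizer('The Eiffel Tower is in Paris', return_tensors='pt')
with torch.no_grad():
    captured = model(**source, output_hidden_states=True).hidden_states[6][0, -1]

# 2. Tokenize an interpretation prompt; the last token stands in for the '[X]' placeholder.
interp = tokenizer('Please repeat the previous message: X', return_tensors='pt')
patch_pos = interp['input_ids'].shape[1] - 1

def patch_hook(module, inputs, output):
    hidden_states = output[0]
    # Patch only on the prefill pass; incremental decoding steps see a length-1 sequence.
    if hidden_states.shape[1] > patch_pos:
        hidden_states[:, patch_pos] = captured
    return (hidden_states,) + output[1:]

# 3. Overwrite the placeholder's hidden state at an early layer and generate.
handle = model.get_submodule('transformer.h.2').register_forward_hook(patch_hook)
try:
    out = model.generate(**interp, max_new_tokens=20)
finally:
    handle.remove()
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

The point of the sketch is the hook: the model never sees the source prompt at generation time, yet the injected vector steers what it says about the placeholder token.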
@@ -198,8 +201,8 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
 
 
 # event listeners
-extra_components = [
-
+extra_components = [interpretation_prompt, interpretation_prompt_examples, original_prompt_raw, *tokens_container,
+original_prompt_btn, *interpretation_bubbles]
 model_chooser.change(reset_model, [model_chooser, *extra_components], extra_components)
 
 for i, btn in enumerate(tokens_container):
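The event-listener change follows a standard Gradio pattern: the same components are passed as both the inputs and the outputs of `.change`, so the callback can read their current values and return replacements, which is how `reset_model` resets the rest of the UI when the model radio changes. A minimal sketch of that pattern (illustrative components, not this app's):

```python
# Minimal sketch of the Gradio "inputs double as outputs" reset pattern.
import gradio as gr

def reset_on_choice(choice, current_prompt):
    # Return one value per output component, in order.
    return f'default prompt for {choice}'

with gr.Blocks() as sketch:
    chooser = gr.Radio(choices=['model-a', 'model-b'], value='model-a', label='Model')
    prompt = gr.Textbox(label='Prompt')
    chooser.change(reset_on_choice, [chooser, prompt], [prompt])

sketch.launch()
```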
|