import os
from copy import deepcopy
from functools import partial
import spaces
import gradio as gr
import torch
from datasets import load_dataset
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from interpret import InterpretationPrompt

MAX_PROMPT_TOKENS = 60

## info
dataset_info = [{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
                {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate', 
                 'filter': lambda x: x['label'] == 1},
               ]



model_info = {
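    # each entry defines two prompt templates: original_prompt_template wraps the user's prompt,
    # while interpretation_prompt_template contains the [X] placeholder whose internal representation
    # is later replaced by the hidden state we want to interpret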
    'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=os.environ['hf_token'], 
                                          original_prompt_template='<s>[INST] {prompt} [/INST]',
                                          interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                         ), # , load_in_8bit=True
    
    'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=os.environ['hf_token'],
                            original_prompt_template='<bos> {prompt}',
                            interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
                           ),
    
    'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu', 
                                               original_prompt_template='<s>[INST] {prompt} [/INST]',
                                               interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                              ),
    
    # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf', 
    #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
    #                                                model_type='llama', hf=True, ctransformers=True,
    #                                                original_prompt_template='<s>[INST] {prompt} [/INST]',
    #                                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    #                                               )
        }


suggested_interpretation_prompts = ["Before responding, let me repeat the message you wrote:", 
                                    "Let me repeat the message:", "Sure, I'll summarize your message:"]


## functions
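# no-op placeholder; its @spaces.GPU decorator presumably lets the (ZeroGPU) Space register a GPU-backed function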
@spaces.GPU
def initialize_gpu():
    pass

def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
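    # wrap the raw prompt in the model's template, run a forward pass, and collect the hidden
    # states of every layer; also reveal one clickable button per prompt token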
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
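    # hidden_states has shape (num_layers + 1, seq_len, hidden_dim); the extra first entry is the embedding-layer output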
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
    token_btns = ([gr.Button(token, visible=True) for token in tokens] 
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
    progress_dummy_output = ''
    return [progress_dummy_output, hidden_states, *token_btns]


@spaces.GPU
def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
    return interpret_prompt.generate(*args, **kwargs)


def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample, 
                       temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i, 
                       num_beams=1):
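    # global_state holds the stacked hidden states with shape (num_layers + 1, num_tokens, hidden_dim);
    # global_state[:, i] selects token i's representation at every layer, and each of these vectors is
    # interpreted by patching it into the interpretation prompt and generating text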

    interpreted_vectors = global_state[:, i]
    length_penalty = -length_penalty   # unintuitively, length_penalty > 0 will make sequences longer, so we negate it

    # generation parameters
    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }
    
    # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
    interpretation_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)

    # generate the interpretations
    generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
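    # run generation with the selected vectors substituted for the [X] placeholder's representation
    # (the exact patching behavior is implemented in InterpretationPrompt.generate)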
    generated = generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs)
    generation_texts = tokenizer.batch_decode(generated)
    progress_dummy_output = ''
    return ([progress_dummy_output] + 
            [gr.Textbox(text.replace('\n', ' '), visible=True, container=False) for text in generation_texts]
           )


## main
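# inference only: gradients are never needed for interpretation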
torch.set_grad_enabled(False)
model_name = 'LLAMA2-7B'

# extract model info
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

# get model
model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])

# demo
json_output = gr.JSON()
css = '''

.bubble {
  border: none;
  border-radius: 10px;
  padding: 10px;
  margin-top: 15px;
  margin-left: 5%;
  width: 70%;
  box-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
}

.even_bubble{
  background: pink;
}

.odd_bubble{
  background: skyblue;
}

.bubble textarea {
  border: none;
  box-shadow: none;
  background: inherit;
  resize: none;
}

.explanation_accordion .svelte-s1r2yt{
  font-weight: bold;
  text-align: start;
}

'''


# '''
# .token_btn{
#   background-color: none;
#   background: none;
#   border: none;
#   padding: 0;
#   font: inherit;
#   cursor: pointer;
#   color: blue; /* default text color */
#   font-weight: bold;
# }

# .token_btn:hover {
#     color: red;
# }

# '''

original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')

with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    global_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# 😎 Self-Interpreting Models')
            gr.Markdown(
                '**👾 This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free-form natural language!!👾**',
                # elem_classes=['explanation_accordion']
            )
            gr.Markdown(
            '''This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)). 
                An honorable mention goes to **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work 🥳), which was less mature but had the same idea in mind. 
                We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!! 
            ''', line_breaks=True)
            
            # gr.Markdown('**πŸ‘Ύ The idea is really simple: models are able to understand their own hidden states by nature! πŸ‘Ύ**',
            #               # elem_classes=['explanation_accordion']
            #             )  
            gr.Markdown(
            '''
            **👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
            According to the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers. 
            So we can inject a representation from (roughly) any layer into any layer! If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand, 
            we expect to get back a summary of the information that exists inside the hidden state. Since the model uses a roughly common latent space, it can understand representations from different layers and different runs!! How cool is that! 😯😯😯
            ''', line_breaks=True)
                
        # with gr.Column(scale=1):    
        #     gr.Markdown('<span style="font-size:180px;">πŸ€”</span>')

    with gr.Group('Interpretation'):
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

    gr.Markdown('''
    Here are some example prompts whose internal representations we can analyze:
    ''')
    
    # for info in dataset_info:
    #     with gr.Tab(info['name']):
    #         num_examples = 10
    #         dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
    #         if 'filter' in info:
    #             dataset = dataset.filter(info['filter'])
    #         dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
    #         dataset = [[row[info['text_col']]] for row in dataset]
    #         gr.Examples(dataset, [original_prompt_raw])
                
    with gr.Group():
        original_prompt_raw.render()
        original_prompt_btn = gr.Button('Compute', variant='primary')

    tokens_container = []
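    # pre-create (hidden) buttons for up to MAX_PROMPT_TOKENS prompt tokens; get_hidden_states reveals the ones in use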
    with gr.Row():
        for i in range(MAX_PROMPT_TOKENS):
            btn = gr.Button('', visible=False, elem_classes=['token_btn'])
            tokens_container.append(btn)
    use_gpu = False # gr.Checkbox(value=False, label='Use GPU')
    progress_dummy = gr.Markdown('', elem_id='progress_dummy')

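    # one output bubble per layer, filled by run_interpretation with that layer's interpretation text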
    interpretation_bubbles = [gr.Textbox('', label=f'Layer {i}', container=False, visible=False, elem_classes=['bubble', 
                                                                                           'even_bubble' if i % 2 == 0 else 'odd_bubble'])
                             for i in range(model.config.num_hidden_layers)]
    
    with gr.Accordion(open=False, label='Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
            # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    # with gr.Group():
    #     with gr.Row():
    #         for txt in model_info.keys():
    #             btn = gr.Button(txt)
    #             model_btns.append(btn)
    #         for btn in model_btns:
    #             btn.click(reset_new_model, [global_state])

    # event listeners
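    # each token button interprets its own token: functools.partial binds the token index i (and use_gpu) to the callback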
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i, use_gpu=use_gpu), [global_state, interpretation_prompt, 
                                                     num_tokens, do_sample, temperature, 
                                                     top_k, top_p, repetition_penalty, length_penalty,
                                                    ], [progress_dummy, *interpretation_bubbles])
    
    original_prompt_btn.click(get_hidden_states, 
                              [original_prompt_raw], 
                              [progress_dummy, global_state, *tokens_container])    
    demo.launch()