File size: 9,417 Bytes
4d6d2dc
 
9b5c8c6
 
4d6d2dc
 
 
 
077e2b3
4d6d2dc
3a2f9b3
4d6d2dc
 
 
 
681bdc6
4d6d2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a614f9
 
 
 
a552026
4d6d2dc
 
5daf90b
4d6d2dc
 
b30a06e
 
2a19f0c
b30a06e
 
2fcc96e
b30a06e
2fcc96e
b30a06e
 
4d6d2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b30a06e
4d6d2dc
f1ed2e8
4d6d2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2f9b3
b30a06e
765296c
 
 
 
 
e9a766f
bca9264
765296c
681bdc6
 
bca9264
681bdc6
 
765296c
 
b4d2f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ece76d
edb0c67
4d6d2dc
 
 
4725944
4d6d2dc
34b25c9
4d6d2dc
a552026
34b25c9
4d6d2dc
 
 
 
 
 
 
 
 
a552026
3a2f9b3
4d6d2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d7840a
4d6d2dc
b30a06e
3a2f9b3
bb22c33
4ece76d
ab00aa7
f1ed2e8
a552026
2fcc96e
765296c
2fcc96e
a552026
765296c
 
 
a552026
765296c
4d6d2dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import os
from copy import deepcopy
from functools import partial
import spaces
import gradio as gr
import torch
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from interpret import InterpretationPrompt

MAX_PROMPT_TOKENS = 30  # number of token-button slots created in the UI

# Hugging Face access token, read once with .get() so that building the table below
# does not raise KeyError when 'hf_token' is unset (not every entry needs a token,
# e.g. the Mistral-7B-Instruct one). Gated repos will instead fail when they are loaded.
_hf_token = os.environ.get('hf_token')

## info
# Per-model settings keyed by Hub repo id.
#   original_prompt_template:       wraps the user's prompt ({prompt}).
#   interpretation_prompt_template: wraps the interpretation prompt ({prompt}); [X] marks
#                                   where a hidden state is substituted during computation.
#   tokenizer (optional):           repo to load the tokenizer from instead of the model repo.
#   ctransformers (optional):       load the model with ctransformers instead of transformers.
# Every other key is forwarded as-is to `from_pretrained`.
model_info = {
    'meta-llama/Llama-2-7b-chat-hf': dict(device_map='cpu', token=_hf_token,
                                          original_prompt_template='<s>[INST] {prompt} [/INST]',
                                          interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                         ), # , load_in_8bit=True
    
    'google/gemma-2b': dict(device_map='cpu', token=_hf_token,
                            original_prompt_template='<bos> {prompt}',
                            interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
                           ),
    
    'mistralai/Mistral-7B-Instruct-v0.2': dict(device_map='cpu', 
                                               original_prompt_template='<s>[INST] {prompt} [/INST]',
                                               interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                              ),
    
    'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf', 
                                                   tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
                                                   model_type='llama', hf=True, ctransformers=True,
                                                   original_prompt_template='<s>[INST] {prompt} [/INST]',
                                                   interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                                  )
        }


# Interpretation prompts offered to the user; the first one is the UI default.
suggested_interpretation_prompts = ["Before responding, let me repeat the message you wrote:", 
                                    "Let me repeat the message:", "Sure, I'll summarize your message:"]


## functions
@spaces.GPU
def initialize_gpu():
    """No-op function decorated with ``spaces.GPU``.

    It is never called in this file. Presumably it exists so the Hugging Face Space
    registers a GPU-decorated function at startup — confirm before removing.
    """

def get_hidden_states(progress, raw_original_prompt):
    """Run the model on the templated prompt and collect its hidden states.

    Args:
        progress: Unused; kept so existing callers' positional order still works.
        raw_original_prompt: User text, substituted into ``original_prompt_template``.

    Returns:
        A list whose first element is ``outputs.hidden_states`` stacked along a new
        leading dimension (each entry with its batch dimension squeezed, detached and
        moved to CPU). It is followed by exactly ``MAX_PROMPT_TOKENS`` button updates:
        one visible button per prompt token, then hidden placeholder buttons.
    """
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    # The prompt templates spell out the BOS token (<s> / <bos>) themselves, so the
    # tokenizer must not add special tokens again.
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    # Decode each token id on its own so every token gets its own button label.
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    # Only MAX_PROMPT_TOKENS button slots exist. Longer prompts previously produced more
    # updates than output components (and a negative padding range), so only the first
    # MAX_PROMPT_TOKENS tokens get a button.
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    token_btns = ([gr.Button(token, visible=True) for token in shown_tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))])
    return [hidden_states, *token_btns]


def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample, 
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i, 
                       num_beams=1, progress=gr.Progress()):
    """Generate free-text interpretations of the hidden states at token position ``i``.

    Args:
        global_state: Stacked hidden states from ``get_hidden_states``; it is indexed as
            ``[:, i]``, so its second axis is assumed to be token position — confirm.
        raw_interpretation_prompt: Text substituted into ``interpretation_prompt_template``.
        max_new_tokens, do_sample, temperature, top_k, top_p, repetition_penalty:
            Forwarded to generation (``max_new_tokens``/``top_k`` cast to ``int``).
        length_penalty: UI value; negated before being forwarded (see below).
        i: Index of the token whose hidden states are interpreted.
        num_beams: Beam count, cast to ``int``.
        progress: Gradio progress tracker (not used in the body).

    Returns:
        One visible, container-less ``gr.Textbox`` per decoded generation.
    """
    token_vectors = global_state[:, i]
    # Negated on purpose: a positive length_penalty makes generation favor *longer*
    # sequences, which is the opposite of what the slider suggests. Note that HF only
    # applies length_penalty in beam-based generation (num_beams > 1).
    effective_length_penalty = -length_penalty

    generation_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=effective_length_penalty,
        num_beams=int(num_beams),
    )

    # Put the user's text into the model-specific template, then wrap it so the
    # hidden states can be injected during generation.
    templated = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
    prompt_obj = InterpretationPrompt(tokenizer, templated)

    # NOTE(review): the meaning of the dict key 0 and of k=3 is defined by
    # InterpretationPrompt.generate, which is not visible here — confirm there.
    generated_ids = prompt_obj.generate(model, {0: token_vectors}, k=3, **generation_kwargs)
    decoded = tokenizer.batch_decode(generated_ids)
    return [gr.Textbox(piece, visible=True, container=False) for piece in decoded]


## main
# Disable autograd globally: the app only runs forward passes and generation.
torch.set_grad_enabled(False)
model_name = 'meta-llama/Llama-2-7b-chat-hf' # 'mistralai/Mistral-7B-Instruct-v0.2' #

# extract model info: split the entry into app-level settings and from_pretrained kwargs.
# deepcopy so the pops below don't mutate the shared model_info table.
model_args = deepcopy(model_info[model_name])
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
# Entries such as the GGUF one load their tokenizer from a different repo.
tokenizer_name = model_args.pop('tokenizer', model_name)
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

# get model
model = AutoModelClass.from_pretrained(model_name, **model_args)
# .get(): ungated tokenizers don't need a token, so a missing 'hf_token' should not
# raise KeyError here; gated repos will report missing credentials themselves.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ.get('hf_token'))

# demo
# NOTE(review): `json_output` is created outside the gr.Blocks context below and is not
# referenced anywhere else in the visible source, so it looks unused — confirm before removing.
json_output = gr.JSON()
# Custom CSS passed to gr.Blocks. It styles elements carrying the `bubble` class (the
# interpretation textboxes) as pink, rounded, bordered bubbles, and removes the border
# of their inner <textarea> while giving it the same pink background.
css = '''
.bubble {
  border: 2px solid #000;
  border-radius: 10px;
  padding: 10px;
  margin-top: 10px;
  background: pink;
}
.bubble > textarea{
  border: none;
  background: pink;
}

'''


# '''
# .token_btn{
#   background-color: none;
#   background: none;
#   border: none;
#   padding: 0;
#   font: inherit;
#   cursor: pointer;
#   color: blue; /* default text color */
#   font-weight: bold;
# }

# .token_btn:hover {
#     color: red;
# }

# '''

# UI layout: header, prompt input, generation settings, interpretation prompt, then
# the per-token buttons and the interpretation bubbles.
with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    # Receives the stacked hidden states returned by get_hidden_states.
    global_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('''
                # 😎 Self-Interpreting Models
                
                👾 **This space is a simple introduction to the emerging trend of models interpreting their _own hidden states_ in free form natural language**!! 👾
                This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)). 
                An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my own work!! 🥳) which was less mature but had the same idea in mind. 
                We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!! 
                
                👾 **The idea is really simple: models are able to understand their own hidden states by nature!** 👾  
                If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace ``[X]`` *during computation* with the hidden state we want to understand, 
                we hope to get back a summary of the information that exists inside the hidden state, because it is encoded in a latent space the model uses itself!! How cool is that! 😯😯😯
                ''', line_breaks=True)
        with gr.Column(scale=1):
            # NOTE(review): the emoji in this span is mis-encoded (mojibake); it may have
            # been meant as 🤔 — confirm and fix.
            gr.Markdown('<span style="font-size:180px;">πŸ€”</span>')

    with gr.Group():
        original_prompt_raw = gr.Textbox(value='Should I eat cake or vegetables?', container=True, label='Original Prompt')
        original_prompt_btn = gr.Button('Compute', variant='primary')

    with gr.Accordion(open=False, label='Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
            # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    # gr.Group has no title/label parameter. The former positional 'Interpretation' /
    # 'Output' strings were never displayed, and depending on the Gradio version they
    # were either rejected or bound to an unrelated option.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

    with gr.Group():
        # One hidden button slot per possible prompt token; get_hidden_states fills them in.
        tokens_container = []
        with gr.Row():
            for _ in range(MAX_PROMPT_TOKENS):
                btn = gr.Button('', visible=False, elem_classes=['token_btn'])
                tokens_container.append(btn)
            # NOTE(review): one bubble per num_hidden_layers. Hugging Face models usually
            # return num_hidden_layers + 1 hidden states (embeddings included), so if
            # run_interpretation yields one text per hidden state, the last one may not be
            # shown — confirm against InterpretationPrompt.generate.
            interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
                                      for _ in range(model.config.num_hidden_layers)]

        # Clicking token button i interprets the hidden states at token position i.
        for i, btn in enumerate(tokens_container):
            btn.click(partial(run_interpretation, i=i),
                      [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
                       top_k, top_p, repetition_penalty, length_penalty],
                      [*interpretation_bubbles])

    # `progress` was never defined, so listing it as an input raised NameError at startup.
    # get_hidden_states does not use its first parameter, so bind it to None and feed
    # only the prompt textbox.
    original_prompt_btn.click(partial(get_hidden_states, None),
                              [original_prompt_raw],
                              [global_state, *tokens_container])

# Launch after the Blocks context has closed so the layout is finalized first.
demo.launch()