import os
import gc
from typing import Optional
from dataclasses import dataclass, field
from copy import deepcopy
from functools import partial
import numpy as np
import spaces
import gradio as gr
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from interpret import InterpretationPrompt
from configs import model_info, dataset_info
MAX_PROMPT_TOKENS = 60
MAX_NUM_LAYERS = 50
welcome_message = '**You are now running {model_name}!!** 🥳🥳🥳'
# Used by the layer and token importance heuristics in this file: the first `avoid_first`
# and last `avoid_last` layers are usually not informative, so they are skipped when
# looking for important layers and tokens.
avoid_first, avoid_last = 3, 2
@dataclass
class LocalState:
hidden_states: Optional[torch.Tensor] = None
@dataclass
class GlobalState:
tokenizer : Optional[PreTrainedTokenizer] = None
model : Optional[PreTrainedModel] = None
sentence_transformer: Optional[PreTrainedModel] = None
local_state : LocalState = field(default_factory=LocalState)
wait_with_hidden_states : bool = False
interpretation_prompt_template : str = '{prompt}'
original_prompt_template : str = 'User: [X]\n\nAssistant: {prompt}'
layers_format : str = 'model.layers.{k}'
suggested_interpretation_prompts = [
"Sure, I'll summarize your message:",
"The meaning of [X] is",
"Sure, here's a bullet list of the key words in your message:",
"Sure, here are the words in your message:",
"Before responding, let me repeat the message you wrote:",
"Let me repeat the message:"
]
## functions
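# initialize_gpu is intentionally empty: on ZeroGPU Spaces the @spaces.GPU decorator is what
# requests GPU time, so exposing a decorated no-op lets the Space warm up GPU allocation
# (assumption about the `spaces` package's intended usage).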
@spaces.GPU
def initialize_gpu():
pass
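# reset_model loads the model and tokenizer described in configs.model_info, stores the prompt
# templates and layer-name format on global_state, optionally (re)loads the sentence transformer,
# and (when with_extra_components is set) returns fresh Gradio components to clear the UI.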
def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
# extract model info
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
global_state.original_prompt_template = model_args.pop('original_prompt_template')
global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
global_state.layers_format = model_args.pop('layers_format')
tokenizer_path = model_args.pop('tokenizer') if 'tokenizer' in model_args else model_path
use_ctransformers = model_args.pop('ctransformers', False)
dont_cuda = model_args.pop('dont_cuda', False)
global_state.wait_with_hidden_states = model_args.pop('wait_with_hidden_states', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM
# get model
global_state.model, global_state.tokenizer, global_state.local_state.hidden_states = None, None, None
gc.collect()
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args)
if reset_sentence_transformer:
global_state.sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
gc.collect()
if not dont_cuda:
global_state.model.to('cuda')
if not load_on_gpu:
global_state.model.to('cpu')
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
gc.collect()
if with_extra_components:
return ([welcome_message.format(model_name=model_name)]
+ [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
+ [gr.Button('', visible=False) for _ in range(len(tokens_container))]
+ [*extra_components])
else:
return None
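# get_hidden_states tokenizes the prompt with the model's chat template and, unless hidden states
# are deferred (wait_with_hidden_states), runs a forward pass to cache all per-layer hidden states
# and highlight the tokens whose representations change most across layers.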
def get_hidden_states(raw_original_prompt, force_hidden_states=False):
model, tokenizer = global_state.model, global_state.tokenizer
original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
if global_state.wait_with_hidden_states and not force_hidden_states:
global_state.local_state.hidden_states = None
important_tokens = [] # cannot find important tokens without the hidden states
else:
outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
# Score each (layer, token) position by how much the token's normalized hidden state changes
# between adjacent layers; the tokens at the top-5 scoring positions are highlighted as important.
hidden_scores = F.normalize(hidden_states[avoid_first-1:len(hidden_states)-avoid_last], dim=-1).diff(dim=0).norm(dim=-1).cpu() # num_layers x num_tokens
important_tokens = np.unravel_index(hidden_scores.flatten().topk(k=5).indices.numpy(), hidden_scores.shape)[1]
print(f'{important_tokens=}\t\t{hidden_states.shape=}')
global_state.local_state.hidden_states = hidden_states.cpu().detach()
token_btns = ([gr.Button(token, visible=True,
elem_classes=['token_btn'] + (['important_token'] if i in important_tokens else [])
)
for i, token in enumerate(tokens)]
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
progress_dummy_output = ''
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
return [progress_dummy_output, *token_btns, *invisible_bubbles]
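# run_interpretation injects the selected token's hidden state (one vector per layer) into the
# interpretation prompt, decodes one generation per layer, and marks as "important" the layers
# whose representations and generations differ most from the preceding layer.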
@spaces.GPU
def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
num_beams=1):
model = global_state.model
tokenizer = global_state.tokenizer
print(f'run {model}')
if use_gpu:
model = model.cuda()
else:
model = model.cpu()
if global_state.wait_with_hidden_states and global_state.local_state.hidden_states is None:
get_hidden_states(raw_original_prompt, force_hidden_states=True)
interpreted_vectors = global_state.local_state.hidden_states[:, i].to(model.device).to(model.dtype)
hidden_means = global_state.local_state.hidden_states.mean(dim=1).to(model.device).to(model.dtype)
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
# generation parameters
generation_kwargs = {
'max_new_tokens': int(max_new_tokens),
'do_sample': do_sample,
'temperature': temperature,
'top_k': int(top_k),
'top_p': top_p,
'repetition_penalty': repetition_penalty,
'length_penalty': length_penalty,
'num_beams': int(num_beams)
}
# create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
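# InterpretationPrompt.generate (defined in interpret.py) is expected to substitute the supplied
# vectors for the placeholder positions during the forward pass (SelfIE-style injection);
# the exact hook mechanics are an assumption here, see interpret.py for details.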
# generate the interpretations
generated = interpretation_prompt.generate(model, {0: interpreted_vectors},
layers_format=global_state.layers_format, k=3,
**generation_kwargs)
generation_texts = tokenizer.batch_decode(generated)
# try identifying important layers
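# diff_score1: how much the injected hidden vector changes between adjacent layers.
# diff_score2: negative normalized overlap of unigram+bigram bags between adjacent layers'
# generations, so a layer whose text differs a lot from the previous one scores higher.
# Both scores are min-max normalized and summed; the top ~30% of layers are kept as "important".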
vectors_to_compare = interpreted_vectors # torch.tensor(global_state.sentence_transformer.encode(generation_texts))
diff_score1 = F.normalize(vectors_to_compare, dim=-1).diff(dim=0).norm(dim=-1).cpu()
tokenized_generations = [tokenizer.tokenize(text) for text in generation_texts]
bags_of_words = [set(tokens) | set([(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)]) for tokens in tokenized_generations]
diff_score2 = torch.tensor([
-len(bags_of_words[i+1] & bags_of_words[i]) / np.sqrt(len(bags_of_words[i+1]) * len(bags_of_words[i]))
for i in range(len(bags_of_words)-1)
])
diff_score = ((diff_score1 - diff_score1.min()) / (diff_score1.max() - diff_score1.min())
+ (diff_score2 - diff_score2.min()) / (diff_score2.max() - diff_score2.min()))
assert avoid_first >= 1 # due to .diff() we will not be able to compute a score for the first layer
diff_score = diff_score[avoid_first-1:len(diff_score)-avoid_last]
important_idxs = avoid_first + diff_score.topk(k=int(np.ceil(0.3 * len(diff_score)))).indices.cpu().numpy()  # shift back to absolute layer indices
# create GUI output
print(f'{important_idxs=}')
progress_dummy_output = ''
elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
([] if i in important_idxs else ['faded_bubble']) for i in range(len(generation_texts))]
bubble_outputs = [gr.Textbox(text.replace('\n', ' '), show_label=True, visible=True,
container=True, label=f'Layer {i}', elem_classes=elem_classes[i])
for i, text in enumerate(generation_texts)]
bubble_outputs += [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
return [progress_dummy_output, *bubble_outputs]
## main
torch.set_grad_enabled(False)
model_name = 'LLAMA2-7B'
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
tokens_container = []
for i in range(MAX_PROMPT_TOKENS):
btn = gr.Button('', visible=False)
tokens_container.append(btn)
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
global_state = GlobalState()
reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
with gr.Row():
with gr.Column(scale=5):
gr.Markdown('# Self-Interpreting Models')
gr.Markdown('<b style="color: #8B0000;">Model outputs are not filtered and might include undesired language!</b>')
gr.Markdown(
'''
**This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free-form natural language!!**
This idea was investigated in the papers **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6)), **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
For concreteness, we will follow the SelfIE implementation in this space.
''', line_breaks=True)
gr.Markdown(
'''
**The idea is really simple: models are able to understand their own hidden states by nature!**
In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
So we can inject a representation from (roughly) any layer into any layer! If we give the model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
we expect to get back a summary of the information inside that hidden state, even though it comes from a different layer and a different run!! How cool is that! 💯💯💯
''', line_breaks=True)
# with gr.Column(scale=1):
# gr.Markdown('<span style="font-size:180px;">π€</span>')
with gr.Group():
# model_chooser = gr.Radio(label='Choose Your Model', choices=list(model_info.keys()), value=model_name)
load_on_gpu = gr.Checkbox(label='Load on GPU', visible=False, value=True)
welcome_model = gr.Markdown(welcome_message.format(model_name=model_name))
with gr.Blocks() as demo_main:
gr.Markdown('## The Prompt to Analyze')
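# One tab of example prompts per entry in configs.dataset_info: each dataset is streamed,
# optionally filtered, shuffled, and truncated to num_examples rows at startup.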
for info in dataset_info:
with gr.Tab(info['name']):
num_examples = 10
dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
if 'filter' in info:
dataset = dataset.filter(info['filter'])
dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
dataset = [[row[info['text_col']]] for row in dataset]
gr.Examples(dataset, [raw_original_prompt], cache_examples=False)
with gr.Group():
raw_original_prompt.render()
original_prompt_btn = gr.Button('Output Token List', variant='primary')
gr.Markdown('**Tokens will appear in the "Tokens" section**')
gr.Markdown('## Choose Your Interpretation Prompt')
with gr.Group('Interpretation'):
raw_interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
[raw_interpretation_prompt], cache_examples=False)
with gr.Accordion(open=False, label='Generation Settings'):
with gr.Row():
num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
# num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
do_sample = gr.Checkbox(label='With sampling')
with gr.Accordion(label='Sampling Parameters'):
with gr.Row():
temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
top_p = gr.Slider(0., 1., value=0.95, label='top p')
gr.Markdown('''
## Tokens
### The tokens of your prompt will appear here (click one to explore it)
''')
with gr.Row():
for btn in tokens_container:
btn.render()
use_gpu = gr.Checkbox(label='Use GPU', visible=False, value=True)
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
interpretation_bubbles = [gr.Textbox('', container=False, visible=False) for i in range(MAX_NUM_LAYERS)]
# event listeners
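# Each token button triggers run_interpretation for its own position (bound with functools.partial);
# the "Output Token List" button recomputes hidden states and token buttons, and editing the
# original prompt hides any stale token buttons.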
for i, btn in enumerate(tokens_container):
btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
num_tokens, do_sample, temperature,
top_k, top_p, repetition_penalty, length_penalty,
use_gpu
], [progress_dummy, *interpretation_bubbles])
original_prompt_btn.click(get_hidden_states,
[raw_original_prompt],
[progress_dummy, *tokens_container, *interpretation_bubbles])
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
# model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
# [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
demo.launch() |