|
import os |
|
import gc |
|
from typing import Optional |
|
from dataclasses import dataclass |
|
from copy import deepcopy |
|
from functools import partial |
|
import spaces |
|
import gradio as gr |
|
import torch |
|
from datasets import load_dataset |
|
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM |
|
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer |
|
from interpret import InterpretationPrompt |
|
|
|
# Number of token buttons pre-allocated in the UI.
MAX_PROMPT_TOKENS = 60

# Hugging Face access token, read once and shared by every model entry below.
# Presumably needed for the gated meta-llama / google repos. It is None when the
# `hf_token` environment variable is unset, so public repos can still be fetched anonymously.
HF_TOKEN = os.environ.get('hf_token')

# Sources of example prompts, shown as one tab each in the UI.
# 'hf_repo' is the dataset repo and 'text_col' is the column used as the prompt text.
# The optional 'filter' is a row predicate applied before sampling examples.
dataset_info = [
    {'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
    {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
     'filter': lambda x: x['label'] == 1},
    {'name': 'Social Reasoning', 'hf_repo': 'ProlificAI/social-reasoning-rlhf', 'text_col': 'question'},
]

# Selectable models. reset_model() pops 'model_path', the two prompt templates and the
# optional 'tokenizer' / 'ctransformers' keys; all remaining keys are forwarded to
# `from_pretrained`. Both templates are filled via str.format(prompt=...).
# The interpretation template also carries the '[X]' placeholder whose representation
# is replaced during generation.
model_info = {
    'LLAMA2-7B': dict(model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=HF_TOKEN,
                      original_prompt_template='<s>{prompt}',
                      interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                      ),
    'Gemma-2B': dict(model_path='google/gemma-2b', device_map='cpu', token=HF_TOKEN,
                     original_prompt_template='<bos>{prompt}',
                     interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
                     ),
    # token passed for consistency with the other entries (harmless for public repos)
    'Mistral-7B Instruct': dict(model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu', token=HF_TOKEN,
                                original_prompt_template='<s>{prompt}',
                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
                                ),
}

# Interpretation prompts offered as clickable examples in the UI.
suggested_interpretation_prompts = [
    "Sure, here's a bullet list of the key words in your message:",
    "Sure, I'll summarize your message:",
    "Sure, here are the words in your message:",
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:"
]
|
|
|
|
|
@dataclass
class GlobalState:
    """Process-wide mutable state shared by the Gradio callbacks.

    Holds the loaded model/tokenizer pair, the hidden states of the most recently
    analysed prompt, and the prompt templates of the active model.
    """

    # Loaded tokenizer and model; None until a model has been loaded.
    tokenizer: Optional[PreTrainedTokenizer] = None
    model: Optional[PreTrainedModel] = None
    # Hidden states stacked over layers for the last analysed prompt (None before any analysis).
    hidden_states: Optional[torch.Tensor] = None
    # str.format templates with a '{prompt}' placeholder for the interpretation / original prompt.
    interpretation_prompt_template: str = '{prompt}'
    original_prompt_template: str = '{prompt}'
|
|
|
|
|
|
|
@spaces.GPU
def initialize_gpu():
    """No-op function decorated with ``@spaces.GPU``.

    NOTE(review): it is never called in this file; presumably it exists only so that a
    GPU-decorated function is registered when the module is imported.
    """
|
|
|
|
|
def reset_model(model_name):
    """Load the model registered as ``model_info[model_name]`` into ``global_state``.

    Installs the entry's prompt templates, drops the previous model, tokenizer and
    cached hidden states, then loads the new model and tokenizer. All keys other than
    'model_path', the two templates, 'tokenizer' and 'ctransformers' are forwarded to
    ``from_pretrained``.

    Args:
        model_name: key into ``model_info``.
    """
    # Deep-copy so that popping keys does not mutate the shared registry.
    model_args = deepcopy(model_info[model_name])
    model_path = model_args.pop('model_path')
    global_state.original_prompt_template = model_args.pop('original_prompt_template')
    global_state.interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
    tokenizer_path = model_args.pop('tokenizer', model_path)
    use_ctransformers = model_args.pop('ctransformers', False)
    AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

    # Release references to the previous model before loading the next one,
    # so its memory can be reclaimed first.
    global_state.model, global_state.tokenizer, global_state.hidden_states = None, None, None
    gc.collect()

    model = AutoModelClass.from_pretrained(model_path, **model_args)
    # ctransformers' LLM objects have no `.cuda()`; only transformers modules are moved here.
    if not use_ctransformers:
        model = model.cuda()
    global_state.model = model
    global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
    gc.collect()
|
|
|
|
|
@spaces.GPU
def get_hidden_states(raw_original_prompt):
    """Run the loaded model on a prompt, cache its hidden states and build token buttons.

    The model was moved to CUDA by reset_model, so this function takes a GPU slot
    like run_interpretation does.

    Args:
        raw_original_prompt: user text, inserted into the model's original-prompt template.

    Returns:
        ``[progress_dummy_output, *token_buttons, *hidden_bubbles]``.
        There are exactly ``MAX_PROMPT_TOKENS`` token buttons; prompts longer than that
        only show their first ``MAX_PROMPT_TOKENS`` tokens. Every interpretation bubble
        is hidden.
    """
    model, tokenizer = global_state.model, global_state.tokenizer
    original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
    # The template already contains the BOS marker, hence add_special_tokens=False.
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])

    # torch.set_grad_enabled(False) at import time is thread-local and does not cover
    # Gradio's worker threads, so disable autograd explicitly for this forward pass.
    with torch.no_grad():
        outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    # Batch of one prompt: squeeze(0) leaves (seq_len, hidden_dim) per entry.
    # Stacking gives (num_hidden_state_entries, seq_len, hidden_dim).
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    # Only MAX_PROMPT_TOKENS button components exist. Returning more updates would shift
    # the remaining values onto the wrong output components.
    shown_tokens = tokens[:MAX_PROMPT_TOKENS]
    token_btns = ([gr.Button(token, visible=True) for token in shown_tokens]
                  + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(shown_tokens))])
    progress_dummy_output = ''
    invisible_bubbles = [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
    global_state.hidden_states = hidden_states
    return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
|
|
|
|
@spaces.GPU
def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, i,
                       num_beams=1):
    """Interpret the cached hidden states of token ``i`` with the interpretation prompt.

    The hidden states of token ``i`` (one vector per stacked hidden-state entry) are
    handed to InterpretationPrompt.generate. The decoded texts are returned as
    visible bubbles.

    Args:
        raw_interpretation_prompt: user text, inserted into the model's interpretation template.
        max_new_tokens, do_sample, temperature, top_k, top_p, repetition_penalty, num_beams:
            generation settings forwarded to ``generate``.
        length_penalty: non-negative UI value; its sign is flipped before use.
        i: index of the prompt token to interpret.

    Returns:
        ``[progress_dummy_output, *bubbles]``, with exactly ``len(interpretation_bubbles)``
        bubbles. Unused bubbles are hidden, so a model with fewer layers still fills
        every output.
    """
    if global_state.hidden_states is None:
        raise gr.Error('Please click "Output Token List" first to analyze a prompt.')
    interpreted_vectors = global_state.hidden_states[:, i]
    # In HF generation, a negative length_penalty favours shorter beams (beam search only).
    # The UI exposes a non-negative "penalty", hence the sign flip.
    length_penalty = -length_penalty

    generation_kwargs = {
        'max_new_tokens': int(max_new_tokens),
        'do_sample': do_sample,
        'temperature': temperature,
        'top_k': int(top_k),
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
        'num_beams': int(num_beams)
    }

    interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)

    generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
    # Fixed: this used to reference an undefined global `tokenizer`.
    generation_texts = global_state.tokenizer.batch_decode(generated)

    # Match the number of bubble components exactly: truncate extra texts and hide unused bubbles.
    num_bubbles = len(interpretation_bubbles)
    bubbles = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {layer_idx}')
               for layer_idx, text in enumerate(generation_texts[:num_bubbles])]
    bubbles += [gr.Textbox('', visible=False) for _ in range(num_bubbles - len(bubbles))]
    progress_dummy_output = ''
    return [progress_dummy_output, *bubbles]
|
|
|
|
|
|
|
# Inference only. NOTE: PyTorch grad mode is thread-local, so this affects the
# importing thread only, not threads that later run the Gradio callbacks.
torch.set_grad_enabled(False)
global_state = GlobalState()

# Load the default model at start-up.
model_name = 'LLAMA2-7B'
reset_model(model_name)

# Created outside the layout and rendered later via `.render()`.
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
# One initially hidden button per possible prompt token.
tokens_container = [gr.Button('', visible=False, elem_classes=['token_btn']) for _ in range(MAX_PROMPT_TOKENS)]
|
|
|
# Gradio layout and event wiring. Event listeners must be registered inside the Blocks context.
with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# π Self-Interpreting Models')
            gr.Markdown('<b style="color: #8B0000;">Model outputs are not filtered and might include undesired language!</b>')
            gr.Markdown(
                '''
**πΎ This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!πΎ**
This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work π₯³) which was less mature but had the same idea in mind.
We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
''', line_breaks=True)
            gr.Markdown(
                '''
**πΎ The idea is really simple: models are able to understand their own hidden states by nature! πΎ**
In line with the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
So we can inject a representation from (roughly) any layer into any layer! If we give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
we expect to get back a summary of the information that exists inside the hidden state, despite being from a different layer and a different run!! How cool is that! π―π―π―
''', line_breaks=True)

    with gr.Group():
        model_chooser = gr.Radio(choices=list(model_info.keys()), value=model_name)

    gr.Markdown('## Choose Your Interpretation Prompt')
    # Layout blocks take keyword arguments only, so the stray positional 'Interpretation' is dropped.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
        gr.Examples([[p] for p in suggested_interpretation_prompts], [interpretation_prompt], cache_examples=False)

    gr.Markdown('## The Prompt to Analyze')
    for info in dataset_info:
        with gr.Tab(info['name']):
            num_examples = 10
            dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
            if 'filter' in info:
                dataset = dataset.filter(info['filter'])
            dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
            dataset = [[row[info['text_col']]] for row in dataset]
            # Each example row has a single column that fills the original-prompt textbox.
            # `global_state` is not a Gradio component and must not appear in the inputs.
            gr.Examples(dataset, [original_prompt_raw], cache_examples=False)

    with gr.Group():
        original_prompt_raw.render()
        original_prompt_btn = gr.Button('Output Token List', variant='primary')

    gr.Markdown('### Here go the tokens of the prompt (click on the one to explore)')
    with gr.Row():
        for btn in tokens_container:
            btn.render()

    with gr.Accordion(open=False, label='Generation Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    progress_dummy = gr.Markdown('', elem_id='progress_dummy')
    # outputs.hidden_states has num_hidden_layers + 1 entries (embedding output plus one per
    # layer), and get_hidden_states stacks all of them, so one bubble is needed per entry.
    # (Previously this referenced an undefined global `model`.)
    interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
                                         elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'])
                              for i in range(global_state.model.config.num_hidden_layers + 1)]

    # Fixed: previously wired to the undefined name `reset_new_model`.
    model_chooser.change(reset_model, [model_chooser], [])

    # The inputs must line up with run_interpretation's positional parameters.
    # The token index is bound via partial.
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i),
                  [interpretation_prompt, num_tokens, do_sample, temperature,
                   top_k, top_p, repetition_penalty, length_penalty],
                  [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states,
                              [original_prompt_raw],
                              [progress_dummy, *tokens_container, *interpretation_bubbles])
    # Editing the prompt hides the (now stale) token buttons.
    original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)

demo.launch()