selfie / app.py
dar-tau's picture
Update app.py
868605b verified
raw
history blame
12.5 kB
import os
from copy import deepcopy
from functools import partial
import spaces
import gradio as gr
import torch
from datasets import load_dataset
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from interpret import InterpretationPrompt
# Number of token buttons pre-created in the UI (one button per prompt token).
MAX_PROMPT_TOKENS = 60


## info
# Example datasets. In the visible part of this file they are only referenced
# by commented-out example-loading code. The optional 'filter' callable keeps
# rows whose 'label' equals 1.
dataset_info = [
    {'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
    {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
     'filter': lambda x: x['label'] == 1},
]

# Per-model settings. 'model_path' and the two prompt templates are popped off
# by the setup code; every remaining key is forwarded to `from_pretrained`.
# The token is read with `os.environ.get` so that a missing `hf_token` variable
# does not raise KeyError at import time for models that do not need a token.
model_info = {
    'LLAMA2-7B': dict(
        model_path='meta-llama/Llama-2-7b-chat-hf', device_map='cpu', token=os.environ.get('hf_token'),
        original_prompt_template='<s>[INST] {prompt} [/INST]',
        interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    ),  # , load_in_8bit=True
    'Gemma-2B': dict(
        model_path='google/gemma-2b', device_map='cpu', token=os.environ.get('hf_token'),
        original_prompt_template='<bos> {prompt}',
        interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
    ),
    'Mistral-7B Instruct': dict(
        model_path='mistralai/Mistral-7B-Instruct-v0.2', device_map='cpu',
        original_prompt_template='<s>[INST] {prompt} [/INST]',
        interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    ),
    # 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
    #                                                tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
    #                                                model_type='llama', hf=True, ctransformers=True,
    #                                                original_prompt_template='<s>[INST] {prompt} [/INST]',
    #                                                interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    #                                                )
}

# Default choices for the interpretation prompt; the first one pre-fills the UI.
suggested_interpretation_prompts = [
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:",
    "Sure, I'll summarize your message:",
]
## functions
@spaces.GPU
def initialize_gpu():
    """Do nothing; an empty function wrapped in ``spaces.GPU``.

    It is not called anywhere in the visible part of this file.
    """
    return None
def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
    """Run the model on the templated prompt and expose its hidden states to the UI.

    Args:
        raw_original_prompt: Text entered by the user; it is substituted into
            the module-level ``original_prompt_template``.
        progress: Gradio progress tracker. It is not used inside the body.

    Returns:
        A list ``[progress_dummy_output, hidden_states, *token_btns]``:
        an empty string, the stacked hidden states (one entry per element of
        ``outputs.hidden_states``, moved to CPU), and exactly
        ``MAX_PROMPT_TOKENS`` buttons: one visible button per prompt token
        followed by hidden padding buttons.

    Raises:
        gr.Error: If the templated prompt has more than ``MAX_PROMPT_TOKENS``
            tokens. Without this check the function would return more buttons
            than the UI has pre-created, so the returned values would no
            longer line up with the registered outputs.
    """
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    # decoding a 1-D tensor of ids with batch_decode yields one string per token
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])

    # fail early (before the expensive forward pass) with a readable message
    if len(tokens) > MAX_PROMPT_TOKENS:
        raise gr.Error(
            f'The prompt is too long: it has {len(tokens)} tokens after templating, '
            f'but at most {MAX_PROMPT_TOKENS} are supported.'
        )

    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    # assumes a batch of size 1, so squeeze(0) drops the batch dim — TODO confirm
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    num_padding = MAX_PROMPT_TOKENS - len(tokens)
    token_btns = [gr.Button(token, visible=True) for token in tokens]
    token_btns += [gr.Button('', visible=False) for _ in range(num_padding)]
    progress_dummy_output = ''
    return [progress_dummy_output, hidden_states, *token_btns]
@spaces.GPU
def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
    """Call ``interpret_prompt.generate`` from inside a ``spaces.GPU``-decorated function.

    All positional and keyword arguments are forwarded unchanged and the
    result of ``generate`` is returned as-is.
    """
    generated = interpret_prompt.generate(*args, **kwargs)
    return generated
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                       temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
                       num_beams=1):
    """Generate an interpretation for every stored hidden state of token ``i``.

    Args:
        global_state: Stored hidden states, indexed as ``global_state[:, i]``.
            Presumably laid out as (layer, token, hidden) — confirm against
            ``get_hidden_states``.
        raw_interpretation_prompt: User text substituted into the module-level
            ``interpretation_prompt_template``.
        max_new_tokens, do_sample, temperature, top_k, top_p,
        repetition_penalty, num_beams: Generation settings forwarded to
            ``generate`` (``max_new_tokens``, ``top_k`` and ``num_beams`` are
            cast to ``int`` first).
        length_penalty: UI value; it is negated before being forwarded.
        use_gpu: When truthy, generation goes through
            ``generate_interpretation_gpu``; otherwise ``generate`` is called
            directly on the prompt object.
        i: Index of the selected prompt token.

    Returns:
        A list whose first element is an empty string, followed by one visible
        ``gr.Textbox`` per decoded generation (newlines replaced by spaces).
        NOTE(review): the UI registers one bubble per
        ``model.config.num_hidden_layers``; confirm the number of generations
        matches.
    """
    # hidden states of the selected token
    token_vectors = global_state[:, i]
    # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
    negated_length_penalty = -length_penalty

    # generation parameters
    gen_config = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=negated_length_penalty,
        num_beams=int(num_beams),
    )

    # put the raw prompt in the right template and wrap it in an InterpretationPrompt
    templated_prompt = interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
    prompt_obj = InterpretationPrompt(tokenizer, templated_prompt)

    # generate the interpretations, optionally through the GPU-decorated wrapper
    injected = {0: token_vectors}
    if use_gpu:
        generated = generate_interpretation_gpu(prompt_obj, model, injected, k=3, **gen_config)
    else:
        generated = prompt_obj.generate(model, injected, k=3, **gen_config)
    decoded_texts = tokenizer.batch_decode(generated)

    # first slot is the dummy progress output, then one bubble per generation
    results = ['']
    for text in decoded_texts:
        results.append(gr.Textbox(text.replace('\n', ' '), visible=True, container=False))
    return results
## main
# inference only: no autograd graph is needed anywhere in this app
torch.set_grad_enabled(False)

model_name = 'LLAMA2-7B'

# extract model info (work on a copy so that popping keys leaves model_info intact)
model_args = deepcopy(model_info[model_name])
model_path = model_args.pop('model_path')
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
# the tokenizer defaults to the model repo unless the entry names a separate one
tokenizer_path = model_args.pop('tokenizer', model_path)
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

# get model (whatever is left in model_args is forwarded to from_pretrained)
model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
# `.get` keeps a missing `hf_token` variable from raising KeyError here;
# the token is then passed as None
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ.get('hf_token'))
# demo
# JSON component instantiated before (and therefore outside of) the Blocks context below.
# NOTE(review): `json_output` is not referenced anywhere else in the visible part of this file —
# confirm it is still needed.
json_output = gr.JSON()
# Custom stylesheet passed to gr.Blocks. `.bubble` styles the per-layer
# interpretation textboxes; even/odd bubbles alternate background colors.
# The `border: none` declaration must end with a semicolon: without it the
# next line is parsed as part of the same (invalid) declaration and both
# `border` and `border-radius` are discarded.
css = '''
.bubble {
border: none;
border-radius: 10px;
padding: 10px;
margin-top: 15px;
margin-left: 5%;
width: 70%;
box-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
}
.even_bubble{
background: pink;
}
.odd_bubble{
background: skyblue;
}
.bubble textarea {
border: none;
box-shadow: none;
background: inherit;
resize: none;
}
.explanation_accordion .svelte-s1r2yt{
font-weight: bold;
text-align: start;
}
'''
# '''
# .token_btn{
# background-color: none;
# background: none;
# border: none;
# padding: 0;
# font: inherit;
# cursor: pointer;
# color: blue; /* default text color */
# font-weight: bold;
# }
# .token_btn:hover {
# color: red;
# }
# '''
# The prompt textbox is created up front and rendered later inside the layout
# (see `original_prompt_raw.render()` below).
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')

with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
    # holds the hidden states computed for the current prompt
    global_state = gr.State([])

    # --- header / explanation ---
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('# 😎 Self-Interpreting Models')
            gr.Markdown(
                '**👾 This space is a simple introduction to the emerging trend of models interpreting their OWN hidden states in free form natural language!!👾**',
                # elem_classes=['explanation_accordion']
            )
            gr.Markdown(
                '''This idea was investigated in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was further explored in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorary mention of **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) - my own work 🥳) which was less mature but had the same idea in mind.
We will follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
''', line_breaks=True)

            # gr.Markdown('**👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**',
            #             # elem_classes=['explanation_accordion']
            #             )
            gr.Markdown(
                '''
**👾 The idea is really simple: models are able to understand their own hidden states by nature! 👾**
According to the residual stream view ([nostalgebraist, 2020](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)), internal representations from different layers are transferable between layers.
So we can inject an representation from (roughly) any layer to any layer! If I give a model a prompt of the form ``User: [X] Assistant: Sure'll I'll repeat your message`` and replace the internal representation of ``[X]`` *during computation* with the hidden state we want to understand,
we expect to get back a summary of the information that exists inside the hidden state. Since the model uses a roughly common latent space, it can understand representations from different layers and different runs!! How cool is that! 😯😯😯
''', line_breaks=True)
        # with gr.Column(scale=1):
        #     gr.Markdown('<span style="font-size:180px;">🤔</span>')

    # --- interpretation prompt ---
    # gr.Group has no label and (to my knowledge) accepts keyword arguments only,
    # so no positional string is passed here.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

    gr.Markdown('''
Here are some examples of prompts we can analyze their internal representations:
''')

    # for info in dataset_info:
    #     with gr.Tab(info['name']):
    #         num_examples = 10
    #         dataset = load_dataset(info['hf_repo'], split='train', streaming=True)
    #         if 'filter' in info:
    #             dataset = dataset.filter(info['filter'])
    #         dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
    #         dataset = [[row[info['text_col']]] for row in dataset]
    #         gr.Examples(dataset, [original_prompt_raw])

    # --- original prompt input ---
    with gr.Group():
        original_prompt_raw.render()
        original_prompt_btn = gr.Button('Compute', variant='primary')

    # --- token buttons: a fixed pool, hidden until a prompt is computed ---
    tokens_container = []
    with gr.Row():
        for i in range(MAX_PROMPT_TOKENS):
            btn = gr.Button('', visible=False, elem_classes=['token_btn'])
            tokens_container.append(btn)

    use_gpu = False  # gr.Checkbox(value=False, label='Use GPU')
    progress_dummy = gr.Markdown('', elem_id='progress_dummy')

    # --- one (initially hidden) output bubble per model layer ---
    interpretation_bubbles = [
        gr.Textbox('', label=f'Layer {i}', container=False, visible=False,
                   elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'])
        for i in range(model.config.num_hidden_layers)
    ]

    # --- generation settings ---
    with gr.Accordion(open=False, label='Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
            # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    # with gr.Group():
    #     with gr.Row():
    #         for txt in model_info.keys():
    #             btn = gr.Button(txt)
    #             model_btns.append(btn)
    #     for btn in model_btns:
    #         btn.click(reset_new_model, [global_state])

    # event listeners
    # each token button interprets the hidden states of its own token index;
    # `partial` binds the index and the GPU flag at registration time
    interpretation_inputs = [global_state, interpretation_prompt,
                             num_tokens, do_sample, temperature,
                             top_k, top_p, repetition_penalty, length_penalty,
                             ]
    for i, btn in enumerate(tokens_container):
        btn.click(partial(run_interpretation, i=i, use_gpu=use_gpu),
                  interpretation_inputs,
                  [progress_dummy, *interpretation_bubbles])

    original_prompt_btn.click(get_hidden_states,
                              [original_prompt_raw],
                              [progress_dummy, global_state, *tokens_container])

demo.launch()