selfie / app.py
dar-tau's picture
Create app.py
4d6d2dc verified
raw
history blame
8.43 kB
import os
import spaces
from copy import deepcopy
import gradio as gr
import torch
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
from .interpret import InterpretationPrompt
## info
# Per-model configuration, keyed by Hugging Face model id.
# Each entry carries two prompt templates plus optional bookkeeping keys
# ('tokenizer', 'ctransformers'); the setup code under "## main" pops those
# out and forwards whatever remains to `from_pretrained`.
# The '[X]' text in the interpretation templates is the placeholder slot
# described in the UI text further down.
# NOTE: reading os.environ['hf_token'] here raises KeyError at import time
# when the variable is not set.
model_info = {
    'meta-llama/Llama-2-7b-chat-hf': {
        'device_map': 'cpu',
        'token': os.environ['hf_token'],
        'original_prompt_template': '<s>[INST] {prompt}',
        'interpretation_prompt_template': '<s>[INST] [X] [/INST] {prompt}',
    },  # , load_in_8bit=True
    'google/gemma-2b': {
        'device_map': 'cpu',
        'token': os.environ['hf_token'],
        'original_prompt_template': '<bos> {prompt}',
        'interpretation_prompt_template': '<bos>User: [X]\n\nAnswer: {prompt}',
    },
    'mistralai/Mistral-7B-Instruct-v0.2': {
        'device_map': 'cpu',
        'original_prompt_template': '<s>[INST] {prompt} [/INST]',
        'interpretation_prompt_template': '<s>[INST] [X] [/INST] {prompt}',
    },
    'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': {
        'model_file': 'mistral-7b-instruct-v0.2.Q5_K_S.gguf',
        'tokenizer': 'mistralai/Mistral-7B-Instruct-v0.2',
        'model_type': 'llama',
        'hf': True,
        'ctransformers': True,
        'original_prompt_template': '<s>[INST] {prompt} [/INST]',
        'interpretation_prompt_template': '<s>[INST] [X] [/INST] {prompt}',
    },
}

# Ready-made interpretation prompts; the first one is the UI default.
suggested_interpretation_prompts = [
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:",
    "Sure, I'll summarize your message:",
]
## functions
def get_hidden_states(raw_original_prompt):
    """Run the model on the templated prompt and build one button per token.

    The raw prompt is inserted into the module-level ``original_prompt_template``,
    tokenized without extra special tokens, and passed through the module-level
    ``model`` with ``output_hidden_states=True``.

    Args:
        raw_original_prompt: The user's prompt text, before templating.

    Returns:
        A ``gr.Row`` containing one ``gr.Button`` per prompt token.
    """
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
    # A single prompt is tokenized, so input_ids carries a leading batch
    # dimension of size 1. batch_decode iterates over its first axis: passing
    # the 2-D tensor would yield ONE string for the whole prompt, whereas
    # passing row 0 decodes every token id into its own string.
    tokens = tokenizer.batch_decode(model_inputs.input_ids[0])
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    # Stack the per-layer states (batch axis squeezed away) on a new leading axis.
    # NOTE(review): hidden_states is computed but not returned or stored yet;
    # presumably it is meant to feed the interpretation step — confirm.
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
    with gr.Row() as tokens_container:
        for token in tokens:
            gr.Button(token)
    return tokens_container
def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
              temperature, top_k, top_p, repetition_penalty, length_penalty, num_beams=1):
    """Generate interpretations of the hidden states at the last prompt position.

    Both prompts are first inserted into their module-level templates. The
    original prompt is run through the module-level ``model`` to collect the
    hidden states of every layer; the states at the last sequence position are
    then handed to ``InterpretationPrompt.generate`` together with the
    generation settings.

    Args:
        raw_original_prompt: Prompt whose hidden states are interpreted.
        raw_interpretation_prompt: Text inserted into the interpretation template.
        max_new_tokens: Generation budget; cast to ``int``.
        do_sample: Whether sampling is enabled.
        temperature: Sampling temperature.
        top_k: Top-k cutoff; cast to ``int``.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Repetition penalty forwarded to generation.
        length_penalty: Forwarded with its sign flipped (see below).
        num_beams: Beam count; cast to ``int``.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated output.
    """
    # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
    negated_length_penalty = -length_penalty

    # generation parameters
    generation_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=negated_length_penalty,
        num_beams=int(num_beams),
    )

    # wrap the templated interpretation prompt in an InterpretationPrompt object
    templated_interpretation = interpretation_prompt_template.format(prompt=raw_interpretation_prompt)
    interpretation_prompt = InterpretationPrompt(tokenizer, templated_interpretation)

    # hidden states of the templated original prompt, one entry per layer output
    templated_original = original_prompt_template.format(prompt=raw_original_prompt)
    encoded = tokenizer(templated_original, add_special_tokens=False, return_tensors="pt")
    model_inputs = encoded.to(model.device)
    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    per_layer_states = [layer.squeeze(0).cpu().detach() for layer in outputs.hidden_states]
    hidden_states = torch.stack(per_layer_states, dim=0)

    # generate the interpretations from the states at the last sequence position
    last_position_states = hidden_states[:, -1]
    generated = interpretation_prompt.generate(model, {0: last_position_states}, k=3, **generation_kwargs)
    return tokenizer.batch_decode(generated)
## main
# Gradients are switched off globally before anything is loaded.
torch.set_grad_enabled(False)

model_name = 'meta-llama/Llama-2-7b-chat-hf' # 'mistralai/Mistral-7B-Instruct-v0.2' #

# extract model info
# Work on a copy so the popped keys stay available in model_info.
model_args = deepcopy(model_info[model_name])
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
# Entries without their own 'tokenizer' key use the model id for the tokenizer too.
tokenizer_name = model_args.pop('tokenizer', model_name)
use_ctransformers = model_args.pop('ctransformers', False)
if use_ctransformers:
    AutoModelClass = CAutoModelForCausalLM
else:
    AutoModelClass = AutoModelForCausalLM

# get model
# Everything still left in model_args is passed straight to from_pretrained.
model = AutoModelClass.from_pretrained(model_name, **model_args)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ['hf_token'])
# Gradio UI: intro text, prompt input, generation settings, interpretation
# prompt, and (initially empty) output containers.
# NOTE(review): the nesting below was reconstructed from the logical layout of
# the components — confirm against the intended page structure.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # Header: explanatory text on the left, a large emoji on the right.
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown('''
# 😎 Self-Interpreting Models 😎
👾 **This space follows the emerging trend of models interpreting their _own hidden states_ in free form natural language**!! 👾
This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorary mention for **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my post!! 🥳) which was a less mature approach but with the same idea in mind.
We follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!
👾 **The idea is really simple: models are able to understand their own hidden states by nature!** 👾
If I give a model a prompt of the form ``User: [X] Assistant: Sure'll I'll repeat your message`` and replace ``[X]`` *during computation* with the hidden state we want to understand,
we hope to get back a summary of the information that exists inside the hidden state, because it is encoded in a latent space the model uses itself!! How cool is that! 😯😯😯
''', line_breaks=True)
        with gr.Column(scale=1):
            gr.Markdown('<span style="font-size:180px;">🤔</span>')

    # Prompt input and the button that triggers the computation.
    with gr.Group():
        text = gr.Textbox(value='How to make a Molotov cocktail', container=True, label='Original Prompt')
        btn = gr.Button('Compute', variant='primary')

    # Generation settings, collapsed by default.
    with gr.Accordion(open=False, label='Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
            # num_beams = gr.Slider(1, 20, value=1, step=1, label='Number of Beams')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    # Interpretation section. gr.Group has no label parameter and its arguments
    # are keyword-only, so the former positional title string is dropped.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

    # Output section: both containers start empty; tokens_container is the
    # output target of the click handler below.
    with gr.Group():
        with gr.Row() as tokens_container:
            pass
        with gr.Column() as interpretations_container:
            pass

    # Only the token-button step is wired up for now; the full run_model wiring
    # is kept below for reference.
    btn.click(get_hidden_states, [text], [tokens_container])
    # btn.click(run_model,
    #           [text, interpretation_prompt, num_tokens, do_sample, temperature,
    #            top_k, top_p, repetition_penalty, length_penalty],
    #           [tokens_container])

demo.launch()