|
import os |
|
import spaces |
|
from copy import deepcopy |
|
import gradio as gr |
|
import torch |
|
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from .interpret import InterpretationPrompt |
|
|
|
|
|
|
|
# Hugging Face access token for gated checkpoints (e.g. Llama-2). Read once
# with ``.get`` rather than ``os.environ['hf_token']``: when the variable is
# unset, ``token=None`` lets ``from_pretrained`` fall back to the default
# Hugging Face credentials instead of crashing this module at import time.
_HF_TOKEN = os.environ.get('hf_token')

# Per-model loading configuration, keyed by Hugging Face repo id.
# Every entry carries two prompt templates that the app pops off before the
# remaining items are forwarded as keyword arguments to ``from_pretrained``:
#   original_prompt_template        -- wraps the user's text via ``{prompt}``
#   interpretation_prompt_template  -- ``[X]`` marks where a hidden state is
#                                      substituted during computation
# The prompts are tokenized with ``add_special_tokens=False``, so the templates
# spell out the BOS token (``<s>`` / ``<bos>``) themselves.
# Optional keys: ``tokenizer`` (tokenizer repo when it differs from the model
# repo) and ``ctransformers`` (load through ctransformers instead of
# transformers, e.g. for GGUF files).
model_info = {
    'meta-llama/Llama-2-7b-chat-hf': dict(
        device_map='cpu', token=_HF_TOKEN,
        original_prompt_template='<s>[INST] {prompt}',
        interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    ),

    'google/gemma-2b': dict(
        device_map='cpu', token=_HF_TOKEN,
        original_prompt_template='<bos> {prompt}',
        interpretation_prompt_template='<bos>User: [X]\n\nAnswer: {prompt}',
    ),

    # The token is passed here too, matching the tokenizer load, which always
    # sends it.
    'mistralai/Mistral-7B-Instruct-v0.2': dict(
        device_map='cpu', token=_HF_TOKEN,
        original_prompt_template='<s>[INST] {prompt} [/INST]',
        interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    ),

    # Quantized GGUF weights loaded via ctransformers (``hf=True``); the
    # tokenizer comes from the original Mistral repo.
    'TheBloke/Mistral-7B-Instruct-v0.2-GGUF': dict(
        model_file='mistral-7b-instruct-v0.2.Q5_K_S.gguf',
        tokenizer='mistralai/Mistral-7B-Instruct-v0.2',
        model_type='llama', hf=True, ctransformers=True,
        original_prompt_template='<s>[INST] {prompt} [/INST]',
        interpretation_prompt_template='<s>[INST] [X] [/INST] {prompt}',
    ),
}

# Candidate texts for the interpretation prompt; the first one is used as the
# UI default.
suggested_interpretation_prompts = [
    "Before responding, let me repeat the message you wrote:",
    "Let me repeat the message:",
    "Sure, I'll summarize your message:",
]
|
|
|
|
|
|
|
def get_hidden_states(raw_original_prompt):
    """Tokenize the original prompt and build one button per token.

    The prompt is wrapped in ``original_prompt_template`` and tokenized without
    extra special tokens (the template supplies them). The model is run once with
    ``output_hidden_states=True``.

    Args:
        raw_original_prompt: The user's text from the "Original Prompt" box.

    Returns:
        A ``gr.Row`` containing one ``gr.Button`` per prompt token.
    """
    original_prompt = original_prompt_template.format(prompt=raw_original_prompt)
    model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)

    # Decode every token id on its own. ``batch_decode(input_ids)`` treats the
    # (1, seq_len) tensor as a batch of one sequence and returns a single string
    # for the whole prompt, which produced only one button.
    tokens = [tokenizer.decode([token_id]) for token_id in model_inputs.input_ids[0].tolist()]

    outputs = model(**model_inputs, output_hidden_states=True, return_dict=True)
    # Stack every layer's state after dropping the leading batch dimension
    # (assumes a batch of 1, as produced above).
    # NOTE(review): hidden_states is computed but neither returned nor stored;
    # either wire it into the UI or drop this forward pass, which is the
    # expensive part of this handler.
    hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)

    # NOTE(review): components created inside an event handler are generally not
    # added to the page; dynamic layouts need ``@gr.render`` in recent Gradio
    # versions -- confirm against the installed Gradio version.
    with gr.Row() as tokens_container:
        for token in tokens:
            gr.Button(token)
    return tokens_container
|
|
|
|
|
def run_model(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
              temperature, top_k, top_p, repetition_penalty, length_penalty, num_beams=1):
    """Interpret the original prompt's last-position hidden states in natural language.

    Generation settings are collected first. Then an ``InterpretationPrompt`` is
    built from ``interpretation_prompt_template`` filled with the user's
    interpretation text. The wrapped original prompt is run through ``model``
    once with ``output_hidden_states=True``. The per-layer states at the final
    position are handed to the interpreter's ``generate``.

    Args:
        raw_original_prompt: Text substituted into ``original_prompt_template``.
        raw_interpretation_prompt: Text substituted into
            ``interpretation_prompt_template``.
        max_new_tokens, top_k, num_beams: Passed through ``int`` before use.
        do_sample, temperature, top_p, repetition_penalty: Forwarded unchanged.
        length_penalty: Forwarded with its sign flipped.

    Returns:
        The list returned by ``tokenizer.batch_decode`` on the generated output.
    """
    # The sign flip is deliberate and happens before anything else.
    flipped_length_penalty = -length_penalty

    generation_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=flipped_length_penalty,
        num_beams=int(num_beams),
    )

    interpreter = InterpretationPrompt(
        tokenizer, interpretation_prompt_template.format(prompt=raw_interpretation_prompt))

    prompt_text = original_prompt_template.format(prompt=raw_original_prompt)
    encoded = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt").to(model.device)
    forward_out = model(**encoded, output_hidden_states=True, return_dict=True)
    # One entry per layer, batch dimension squeezed away (assumes batch size 1).
    layer_states = torch.stack(
        [state.squeeze(0).cpu().detach() for state in forward_out.hidden_states], dim=0)

    # Every layer's state at the final position, keyed by 0. The meaning of the
    # key and of ``k`` is defined by InterpretationPrompt.generate (not visible
    # here).
    generated_ids = interpreter.generate(model, {0: layer_states[:, -1]}, k=3, **generation_kwargs)
    return tokenizer.batch_decode(generated_ids)
|
|
|
|
|
|
|
# Inference only: disable autograd globally so forward passes build no graph.
torch.set_grad_enabled(False)

model_name = 'meta-llama/Llama-2-7b-chat-hf'

# Split the selected entry into app-level settings (prompt templates, tokenizer
# repo, loader choice) and the remaining keyword arguments for
# ``from_pretrained``. Work on a copy so ``model_info`` itself stays intact.
model_args = deepcopy(model_info[model_name])
original_prompt_template = model_args.pop('original_prompt_template')
interpretation_prompt_template = model_args.pop('interpretation_prompt_template')
# Fall back to the model repo when no separate tokenizer repo is configured.
tokenizer_name = model_args.pop('tokenizer', model_name)
use_ctransformers = model_args.pop('ctransformers', False)
AutoModelClass = CAutoModelForCausalLM if use_ctransformers else AutoModelForCausalLM

model = AutoModelClass.from_pretrained(model_name, **model_args)
# ``.get``: when hf_token is unset, token=None lets transformers use the default
# Hugging Face credentials instead of raising KeyError here.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=os.environ.get('hf_token'))
|
|
|
with gr.Blocks(theme=gr.themes.Default()) as demo:
    # Header: description on the left, a large emoji on the right.
    with gr.Row():
        with gr.Column(scale=5):
            # NOTE(review): the emoji in this text and in the side column look
            # mis-encoded (e.g. "πΎ", "π€"); restore the intended characters.
            gr.Markdown('''
# π Self-Interpreting Models π

πΎ **This space follows the emerging trend of models interpreting their _own hidden states_ in free-form natural language**!! πΎ
This idea was explored in the paper **Patchscopes** ([Ghandeharioun et al., 2024](https://arxiv.org/abs/2401.06102)) and was later investigated further in **SelfIE** ([Chen et al., 2024](https://arxiv.org/abs/2403.10949)).
An honorary mention goes to **Speaking Probes** ([Dar, 2023](https://towardsdatascience.com/speaking-probes-self-interpreting-models-7a3dc6cb33d6) -- my post!! π₯³), a less mature approach with the same idea in mind.
We follow the SelfIE implementation in this space for concreteness. Patchscopes are so general that they encompass many other interpretation techniques too!!!

πΎ **The idea is really simple: models are able to understand their own hidden states by nature!** πΎ
If I give a model a prompt of the form ``User: [X] Assistant: Sure, I'll repeat your message`` and replace ``[X]`` *during computation* with the hidden state we want to understand,
we hope to get back a summary of the information that exists inside the hidden state, because it is encoded in a latent space the model uses itself!! How cool is that! π―π―π―
''', line_breaks=True)
        with gr.Column(scale=1):
            gr.Markdown('<span style="font-size:180px;">π€</span>')

    # Prompt input and trigger.
    with gr.Group():
        text = gr.Textbox(value='How to make a Molotov cocktail', container=True, label='Original Prompt')
        btn = gr.Button('Compute', variant='primary')

    # Generation settings.
    with gr.Accordion(open=False, label='Settings'):
        with gr.Row():
            num_tokens = gr.Slider(1, 100, step=1, value=20, label='Max. # of Tokens')
            repetition_penalty = gr.Slider(1., 10., value=1, label='Repetition Penalty')
            length_penalty = gr.Slider(0, 5, value=0, label='Length Penalty')
        do_sample = gr.Checkbox(label='With sampling')
        with gr.Accordion(label='Sampling Parameters'):
            with gr.Row():
                temperature = gr.Slider(0., 5., value=0.6, label='Temperature')
                top_k = gr.Slider(1, 1000, value=50, step=1, label='top k')
                top_p = gr.Slider(0., 1., value=0.95, label='top p')

    # gr.Group has no label/title parameter, so the former positional strings
    # ('Interpretation', 'Output') never rendered as headings. They are dropped.
    with gr.Group():
        interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')

    with gr.Group():
        with gr.Row() as tokens_container:
            pass
        with gr.Column() as interpretations_container:
            pass

    # NOTE(review): get_hidden_states creates its Buttons inside the handler and
    # returns a new Row; such components are generally not rendered into
    # tokens_container. Consider ``@gr.render`` (recent Gradio) -- confirm
    # against the installed version.
    btn.click(get_hidden_states, [text], [tokens_container])

demo.launch()