# selfie/interpret.py
import re
from collections import defaultdict
from contextlib import AbstractContextManager

import numpy as np
import torch

# helper functions
def item(x):
    # extract a Python scalar from a length-1 array/list/tensor
    return np.array(x).item()

def _prompt_to_parts(prompt, repeat=5):
    # To allow easy formatting of prompts, we take string prompts
    # in the format "[INST] [X] [/INST] Sure, I'll summarize this"
    # and split them into a list of strings ["[INST]", 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"].
    # Notice how each instance of [X] is replaced by multiple 0 placeholders (`repeat` of them).
    # This is in line with the SELFIE paper, where each interpreted token is inserted 5 times,
    # probably to make it harder for the interpretation to ignore it.
    split_prompt = re.split(r' *\[X\]', prompt)
    parts = []
    for i in range(len(split_prompt)):
        cur_part = split_prompt[i]
        if cur_part != '':
            # if there are multiple [X] in succession, re.split leaves a '' between them
            parts.append(cur_part)
        if i < len(split_prompt) - 1:
            parts.extend([0] * repeat)
    print('Prompt parts:', parts)
    return parts
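
# A quick illustration of the multiple-[X] case handled above (illustrative
# inputs, not from the original file):
#   _prompt_to_parts("A [X] [X] B", repeat=2)
# re.split yields ['A', '', ' B']; the '' is skipped, so the result is
#   ['A', 0, 0, 0, 0, ' B'].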


class Hook(AbstractContextManager):
    # Hook could easily be absorbed into SubstitutionHook, but I like having
    # them both; it seems like the right way from an aesthetic point of view.
    def __init__(self, module, fn):
        self.registered_hook = module.register_forward_hook(fn)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        self.registered_hook.remove()
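
# Usage sketch (the submodule path is a hypothetical example): temporarily
# observe a layer's output, with automatic hook removal on exit.
#
#     with Hook(model.get_submodule('model.layers.5'),
#               lambda mod, inp, out: print(out[0].shape)):
#         model(input_ids)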


class SubstitutionHook(Hook):
    # This is where the substitution takes place; it is used by InterpretationPrompt below.
    def __init__(self, module, positions_dict, values_dict):
        assert set(positions_dict.keys()) == set(values_dict.keys())
        keys = positions_dict.keys()

        def fn(module, input, output):
            device = output[0].device
            dtype = output[0].dtype
            for key in keys:
                num_positions = len(positions_dict[key])
                # batch_size x num_positions x hidden_dim
                values = values_dict[key].unsqueeze(1).expand(-1, num_positions, -1)
                positions = positions_dict[key]
                print(f'{positions=} {values.shape=} {output[0].shape=}')
                output[0][:, positions, :] = values.to(dtype).to(device)
            # Remove the hook after the first forward pass: in generation with
            # use_cache=True, subsequent steps are processed one token at a time,
            # so the placeholder positions no longer exist in the input.
            self.registered_hook.remove()
            return output

        self.registered_hook = module.register_forward_hook(fn)
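
# Shape contract (as used by InterpretationPrompt.generate below): each
# positions_dict[key] is a list of token positions in the prompt, and each
# values_dict[key] is a (batch_size x hidden_dim) tensor whose rows are
# broadcast to all of those positions.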


# main class
class InterpretationPrompt:
    def __init__(self, tokenizer, prompt, placeholder_token=' '):
        prompt_parts = _prompt_to_parts(prompt)
        if placeholder_token is None:
            placeholder_token_id = tokenizer.eos_token_id
        else:
            placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
            # make sure the placeholder does not collide with the EOS token
            assert placeholder_token_id != tokenizer.eos_token_id
        self.tokens = []
        self.placeholders = defaultdict(list)  # placeholder id -> token positions
        for part in prompt_parts:
            if isinstance(part, str):
                self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
            elif isinstance(part, int):
                self.placeholders[part].append(len(self.tokens))
                self.tokens.append(placeholder_token_id)
            else:
                raise NotImplementedError
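
    # For example, for parts ['[INST]', 0, 0, 0, 0, 0, " [/INST] ..."],
    # self.tokens holds the tokenized '[INST]', five placeholder_token_id
    # entries, then the tokenized suffix, while self.placeholders[0] lists
    # the positions of those five placeholder tokens.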

    def generate(self, model, embeds, k, layers_format='model.layers.{k}', **generation_kwargs):
        num_seqs = len(embeds[0])  # assumes the placeholder 0 exists
        tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
        module = model.get_submodule(layers_format.format(k=k))
        with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
            generated = model.generate(tokens_batch, **generation_kwargs)
        return generated
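

if __name__ == '__main__':
    # Minimal end-to-end sketch, not part of the original module. It assumes
    # GPT-2 purely for illustration (hence layers_format='transformer.h.{k}');
    # the SELFIE paper uses a Llama-style chat prompt instead, and the
    # interpretation template below is an arbitrary example.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    k = 6  # layer to read from and inject into (illustrative choice)

    interpretation_prompt = InterpretationPrompt(tokenizer, "The meaning of [X] is:")

    # Grab the layer-k hidden state of the last token of some text to interpret.
    inputs = tokenizer("The quick brown fox", return_tensors='pt')
    with torch.no_grad():
        hidden = model(**inputs, output_hidden_states=True).hidden_states[k]
    embeds = {0: hidden[:, -1]}  # placeholder 0 -> (batch_size=1, hidden_dim)

    generated = interpretation_prompt.generate(model, embeds, k=k,
                                               layers_format='transformer.h.{k}',
                                               max_new_tokens=20, do_sample=False)
    print(tokenizer.batch_decode(generated))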