# selfie / interpret.py
import os
import re
from collections import defaultdict

import numpy as np
import torch
from torch import nn
from contextlib import AbstractContextManager


# helper functions
def item(x):
    return np.array(x).item()
def _prompt_to_parts(prompt, repeat=5):
    # In order to allow easy formatting of prompts, we take string prompts
    # in the format "[INST] [X] [/INST] Sure, I'll summarize this"
    # and split them into a list of parts ["[INST]", 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"].
    # Notice how each instance of [X] is replaced by multiple 0 placeholders (`repeat` of them).
    # This is in line with the SELFIE paper, where each interpreted token is inserted 5 times,
    # probably to make the interpretation less likely to ignore it.
    split_prompt = re.split(r' *\[X\]', prompt)
    parts = []
    for i in range(len(split_prompt)):
        cur_part = split_prompt[i]
        if cur_part != '':
            # if we have multiple [X] in succession, there will be a '' between them in split_prompt
            parts.append(cur_part)
        if i < len(split_prompt) - 1:
            parts.extend([0] * repeat)
    print('Prompt parts:', parts)
    return parts
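
# A quick illustration of the parsing above (hypothetical prompt, not part of the module's API):
#
#   >>> _prompt_to_parts("[INST] [X] [/INST] Sure, I'll summarize this")
#   Prompt parts: ['[INST]', 0, 0, 0, 0, 0, " [/INST] Sure, I'll summarize this"]
#
# Note that `parts.extend([0] * repeat)` always uses the literal 0, so in this
# implementation every [X] maps to the same placeholder key 0.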
class Hook(AbstractContextManager):
    # Hook could easily be absorbed into SubstitutionHook, but I prefer to keep both.
    # Seems like the right way from an aesthetic point of view.
    def __init__(self, module, fn):
        self.registered_hook = module.register_forward_hook(fn)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        self.registered_hook.remove()
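
# A minimal usage sketch (the module path and callback below are illustrative,
# not part of this file):
#
#   def print_output_norm(module, input, output):
#       print(output[0].norm())
#
#   with Hook(model.get_submodule('model.layers.3'), print_output_norm):
#       model(input_ids)  # the callback fires on this forward pass
#
# The context-manager form guarantees the forward hook is removed even if the
# forward pass raises.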
class SubstitutionHook(Hook):
    # This is where the substitution takes place; it is used by InterpretationPrompt below.
    def __init__(self, module, positions_dict, values_dict):
        assert set(positions_dict.keys()) == set(values_dict.keys())
        keys = positions_dict.keys()

        def fn(module, input, output):
            device = output[0].device
            dtype = output[0].dtype
            for key in keys:
                num_positions = len(positions_dict[key])
                values = values_dict[key].unsqueeze(1).expand(-1, num_positions, -1)  # batch_size x num_positions x hidden_dim
                positions = positions_dict[key]
                print(f'{positions=} {values.shape=} {output[0].shape=}')
                output[0][:, positions, :] = values.to(dtype).to(device)
            # In generation with use_cache=True, only the first step sees the full prompt;
            # later steps process one token at a time, so the hook must fire exactly once.
            self.registered_hook.remove()
            return output

        self.registered_hook = module.register_forward_hook(fn)
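
# Shape convention, for orientation (the values here are illustrative, not enforced):
# positions_dict maps a placeholder key to the token positions it occupies, e.g.
# {0: [1, 2, 3, 4, 5]}, and values_dict maps the same key to a batch_size x hidden_dim
# tensor of hidden states, e.g. {0: torch.randn(8, 4096)}; each hidden state is then
# written across all of its positions in the corresponding sequence.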
# main class
class InterpretationPrompt:
    def __init__(self, tokenizer, prompt, placeholder_token=' '):
        prompt_parts = _prompt_to_parts(prompt)
        if placeholder_token is None:
            placeholder_token_id = tokenizer.eos_token_id
        else:
            placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
            assert placeholder_token_id != tokenizer.eos_token_id
        self.tokens = []
        self.placeholders = defaultdict(list)  # placeholder key -> token positions in self.tokens
        for part in prompt_parts:
            if isinstance(part, str):
                self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
            elif isinstance(part, int):
                self.placeholders[part].append(len(self.tokens))
                self.tokens.append(placeholder_token_id)
            else:
                raise NotImplementedError
    def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
        num_seqs = len(embeds[0])  # assumes the placeholder key 0 exists
        # move the token batch to the model's device to avoid a CPU/GPU mismatch
        tokens_batch = torch.tensor([self.tokens[:] for _ in range(num_seqs)]).to(model.device)
        module = model.get_submodule(layer_format.format(k=k))
        with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
            generated = model.generate(tokens_batch, **generation_kwargs)
        return generated
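
# End-to-end usage sketch. Everything below is illustrative: the model name, the layer
# indices, and the way `embeds` is built are assumptions, not part of this module.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.1')
#   model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.1')
#
#   # Hidden states to interpret: here, layer-10 activations of some source prompt.
#   source = tokenizer('The quick brown fox', return_tensors='pt')
#   hidden = model(**source, output_hidden_states=True).hidden_states[10]
#
#   interp = InterpretationPrompt(tokenizer, "[INST] [X] [/INST] Sure, I'll summarize this")
#   embeds = {0: hidden[0]}  # key 0 -> one interpretation per source token
#   out = interp.generate(model, embeds, k=2, max_new_tokens=30)
#   print(tokenizer.batch_decode(out))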