|
import os |
|
import re |
|
from collections import defaultdict |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
from contextlib import AbstractContextManager |
|
|
|
|
|
|
|
def item(x):
    """Return the sole element of ``x`` as a Python scalar.

    ``x`` may be anything NumPy can turn into an array (scalar, list, nested
    list, ndarray). ``ndarray.item()`` raises ``ValueError`` unless the array
    holds exactly one element, so this doubles as a "must be a single value"
    check.
    """
    as_array = np.array(x)
    return as_array.item()
|
|
|
def _prompt_to_parts(prompt, repeat=5): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
split_prompt = re.split(r' *\[X\]', prompt) |
|
parts = [] |
|
for i in range(len(split_prompt)): |
|
cur_part = split_prompt[i] |
|
if cur_part != '': |
|
|
|
parts.append(cur_part) |
|
if i < len(split_prompt) - 1: |
|
parts.extend([0] * repeat) |
|
print('Prompt parts:', parts) |
|
return parts |
|
|
|
|
|
class Hook(AbstractContextManager):
    """Context manager owning a forward hook on a ``torch.nn.Module``.

    The hook ``fn`` is registered as soon as the object is constructed and is
    detached either by an explicit ``close()`` or when a ``with`` block that
    uses this object exits. Entering the ``with`` block yields the hook object
    itself (inherited ``AbstractContextManager.__enter__``).
    """

    def __init__(self, module, fn):
        handle = module.register_forward_hook(fn)
        # Handle returned by PyTorch; its .remove() detaches the hook.
        self.registered_hook = handle

    def __exit__(self, type, value, traceback):
        # Always detach; returning None lets any exception from the with-body propagate.
        self.close()

    def close(self):
        """Detach the forward hook from its module."""
        handle = self.registered_hook
        handle.remove()
|
|
|
|
|
class SubstitutionHook(Hook):
    """One-shot forward hook that overwrites a module's output at given positions.

    On the first forward call of ``module`` after construction, for every key
    the vectors ``values_dict[key]`` are written into the module's output
    activations at sequence positions ``positions_dict[key]``. The hook then
    removes itself, so later forward calls are left untouched.

    The patched tensor is the output itself when the module returns a plain
    tensor; otherwise it is ``output[0]`` (e.g. a tuple whose first element is
    the hidden states). It is indexed as ``[:, positions, :]``, so it must be
    3-D. Presumably this is (batch, seq, hidden) — TODO confirm.

    Args:
        module: the ``nn.Module`` whose forward output is patched in place.
        positions_dict: mapping key -> list of sequence positions to overwrite.
        values_dict: mapping key -> 2-D tensor of replacement vectors. Each
            row is broadcast over all positions of its key. Presumably
            (batch, hidden) — TODO confirm. It is cast to the output's
            dtype/device before writing.

    Raises:
        AssertionError: if the two dicts do not have the same key set.
    """

    def __init__(self, module, positions_dict, values_dict):
        assert set(positions_dict.keys()) == set(values_dict.keys())
        # Snapshot the keys validated above. A live .keys() view could pick up
        # keys added later (e.g. via defaultdict access) that values_dict lacks.
        keys = list(positions_dict.keys())

        def fn(module, input, output):
            # Modules may return a bare tensor or a tuple-like whose first
            # element holds the activations; indexing a bare tensor with [0]
            # would select batch row 0 instead.
            hidden = output if isinstance(output, torch.Tensor) else output[0]
            for key in keys:
                positions = positions_dict[key]
                # (N, D) -> (N, len(positions), D): same vector at every position.
                values = values_dict[key].unsqueeze(1).expand(-1, len(positions), -1)
                print(f'{positions=} {values.shape=} {output[0].shape=}')
                hidden[:, positions, :] = values.to(device=hidden.device, dtype=hidden.dtype)
            # One-shot: only the first forward pass after registration is patched.
            self.close()
            return output

        super().__init__(module, fn)
|
|
|
|
|
|
|
class InterpretationPrompt:
    """A tokenized prompt whose ``[X]`` slots are filled with injected activations.

    The prompt text is split by ``_prompt_to_parts`` (called with its default
    ``repeat``), so every ``[X]`` marker becomes a run of placeholder tokens.
    During ``generate`` the activations at those token positions are
    overwritten by ``SubstitutionHook`` at a chosen layer.

    Attributes:
        tokens: list of token ids for the whole prompt, placeholders included.
        placeholders: ``defaultdict(list)`` mapping placeholder key -> list of
            indices into ``tokens`` occupied by that placeholder.
    """

    def __init__(self, tokenizer, prompt, placeholder_token=' '):
        """Tokenize ``prompt`` and record where the placeholder slots are.

        Args:
            tokenizer: object providing ``encode(text, add_special_tokens=False)``
                and ``eos_token_id``.
            prompt: prompt text containing ``[X]`` markers.
            placeholder_token: text that must encode to exactly one token, used
                as filler at placeholder positions. ``None`` uses the EOS id.

        Raises:
            ValueError: if ``placeholder_token`` does not encode to exactly one
                token (raised by ``item``).
            AssertionError: if ``placeholder_token`` encodes to the EOS id.
            NotImplementedError: if a prompt part is neither ``str`` nor ``int``.
        """
        prompt_parts = _prompt_to_parts(prompt)
        if placeholder_token is None:
            placeholder_token_id = tokenizer.eos_token_id
        else:
            placeholder_token_id = item(tokenizer.encode(placeholder_token, add_special_tokens=False))
            assert placeholder_token_id != tokenizer.eos_token_id
        self.tokens = []
        self.placeholders = defaultdict(list)
        for part in prompt_parts:
            if isinstance(part, str):
                self.tokens.extend(tokenizer.encode(part, add_special_tokens=False))
            elif isinstance(part, int):
                # Remember the slot index before appending the filler token.
                self.placeholders[part].append(len(self.tokens))
                self.tokens.append(placeholder_token_id)
            else:
                raise NotImplementedError

    def generate(self, model, embeds, k, layer_format='model.layers.{k}', **generation_kwargs):
        """Run ``model.generate`` with placeholder activations replaced at layer ``k``.

        Args:
            model: model exposing ``get_submodule`` and ``generate``.
            embeds: mapping placeholder key -> replacement vectors. Its keys must
                match ``self.placeholders``. ``len(embeds[0])`` is used as the
                number of sequences in the batch.
            k: layer index substituted into ``layer_format``.
            layer_format: dotted submodule path template for the hooked layer.
            **generation_kwargs: forwarded to ``model.generate``. An explicit
                ``attention_mask`` here overrides the default all-ones mask.

        Returns:
            Whatever ``model.generate`` returns.
        """
        num_seqs = len(embeds[0])
        # Inputs must live on the model's device; a CPU tensor fed to a GPU
        # model fails at the embedding lookup.
        device = getattr(model, 'device', None)
        if device is None:
            device = next(model.parameters()).device
        tokens_batch = torch.tensor([self.tokens] * num_seqs, device=device)
        # Every prompt token is real. Without an explicit mask, generate() may
        # infer one from the pad id and mask out prompt tokens that equal it.
        generation_kwargs.setdefault('attention_mask', torch.ones_like(tokens_batch))
        module = model.get_submodule(layer_format.format(k=k))
        with SubstitutionHook(module, positions_dict=self.placeholders, values_dict=embeds):
            generated = model.generate(tokens_batch, **generation_kwargs)
        return generated
|
|