"""
Contains utilities for extracting token representations and indices
from string templates. Used in computing the left and right vectors for ROME.
"""
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util import nethook


def get_reprs_at_word_tokens(
model: AutoModelForCausalLM,
tok: AutoTokenizer,
context_templates: List[str],
words: List[str],
layer: int,
module_template: str,
subtoken: str,
track: str = "in",
) -> torch.Tensor:
"""
Retrieves the last token representation of `word` in `context_template`
when `word` is substituted into `context_template`. See `get_last_word_idx_in_template`
for more details.
"""
idxs = get_words_idxs_in_templates(tok, context_templates, words, subtoken)
return get_reprs_at_idxs(
model,
tok,
[context_templates[i].format(words[i]) for i in range(len(words))],
idxs,
layer,
module_template,
track,
)


def get_words_idxs_in_templates(
    tok: AutoTokenizer, context_templates: List[str], words: List[str], subtoken: str
) -> List[List[int]]:
"""
Given list of template strings, each with *one* format specifier
(e.g. "{} plays basketball"), and words to be substituted into the
template, computes the post-tokenization index of their last tokens.
"""
assert all(
tmp.count("{}") == 1 for tmp in context_templates
), "We currently do not support multiple fill-ins for context"
prefixes_len, words_len, suffixes_len, inputs_len = [], [], [], []
for i, context in enumerate(context_templates):
prefix, suffix = context.split("{}")
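        # Measure token counts cumulatively (prefix, prefix+word,
        # prefix+word+suffix) so that context-dependent tokenization at the
        # boundaries is accounted for.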
prefix_len = len(tok.encode(prefix))
prompt_len = len(tok.encode(prefix + words[i]))
input_len = len(tok.encode(prefix + words[i] + suffix))
prefixes_len.append(prefix_len)
words_len.append(prompt_len - prefix_len)
suffixes_len.append(input_len - prompt_len)
inputs_len.append(input_len)
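    # Continuing the example above: prefixes_len[i] = 0, words_len[i] = 2,
    # suffixes_len[i] = 2, inputs_len[i] = 4.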
    # Compute indices of the requested subtokens
    if subtoken == "last" or subtoken == "first_after_last":
        # If the suffix is empty, there is no "first token after the last",
        # so fall back to the last token of the word.
        return [
            [
                prefixes_len[i]
                + words_len[i]
                - (1 if subtoken == "last" or suffixes_len[i] == 0 else 0)
            ]
            for i in range(len(context_templates))
        ]
elif subtoken == "first":
return [[prefixes_len[i] - inputs_len[i]] for i in range(len(context_templates))]
else:
raise ValueError(f"Unknown subtoken type: {subtoken}")


def get_reprs_at_idxs(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    contexts: List[str],  # full sentences expressing the fact
    idxs: List[List[int]],  # token positions of the substituted words
    layer: int,
    module_template: str,
    track: str = "in",
) -> torch.Tensor:
"""
Runs input through model and returns averaged representations of the tokens
at each index in `idxs`.
"""
    def _batch(n):
        # Yield (contexts, idxs) chunks of size n.
        for i in range(0, len(contexts), n):
            yield contexts[i : i + n], idxs[i : i + n]
assert track in {"in", "out", "both"}
both = track == "both"
tin, tout = (
(track == "in" or both),
(track == "out" or both),
)#tin tout都是bool结构
module_name = module_template.format(layer)
to_return = {"in": [], "out": []}
    def _process(cur_repr, batch_idxs, key):
        nonlocal to_return
        cur_repr = cur_repr[0] if type(cur_repr) is tuple else cur_repr
        # Some modules yield sequence-first tensors; transpose to batch-first.
        if cur_repr.shape[0] != len(batch_idxs):
            cur_repr = cur_repr.transpose(0, 1)
        for i, idx_list in enumerate(batch_idxs):
            # Average the representations at the given token indices.
            to_return[key].append(cur_repr[i][idx_list].mean(0))
    for batch_contexts, batch_idxs in _batch(n=128):
        # Tokenize the batch; "input_ids" has shape [batch_size, seq_len].
        contexts_tok = tok(batch_contexts, padding=True, return_tensors="pt").to(
            next(model.parameters()).device
        )
with torch.no_grad():
with nethook.Trace(
module=model,
layer=module_name,
retain_input=tin,
retain_output=tout,
) as tr:
model(**contexts_tok)
if tin:
_process(tr.input, batch_idxs, "in")
if tout:
_process(tr.output, batch_idxs, "out")
to_return = {k: torch.stack(v, 0) for k, v in to_return.items() if len(v) > 0}
if len(to_return) == 1:
return to_return["in"] if tin else to_return["out"]
else:
return to_return["in"], to_return["out"]