Charles Lin
Add algorithms from efk codebase
e56055d
raw
history blame
11.2 kB
# Adapted from https://github.com/nicola-decao/KnowledgeEditor/blob/main/src/models/one_shot_learner.py
"""
@inproceedings{decao2020editing,
title={Editing Factual Knowledge in Language Models},
author={Nicola De Cao and Wilker Aziz and Ivan Titov},
booktitle={arXiv pre-print 2104.08164},
url={https://arxiv.org/abs/2104.08164},
year={2021},
}
"""
import torch
import copy
import higher
from higher.patch import monkeypatch as make_functional
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
import logging
from editable_model import EditableModel
from utils import _logits, _inner_params
from models import BertClassifier
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration
LOG = logging.getLogger(__name__)
class KE(EditableModel):
    """Editable model that pairs a base model with a learned editor network.

    The editor receives the token ids / attention mask of the edit example
    together with the gradients of the edit loss w.r.t. the configured inner
    parameters, and returns a dict of additive updates for those parameters.
    """

    def __init__(self, model, config, model_constructor, editor=None):
        """Wrap ``model`` and build a ``OneShotLearner`` editor when none is given.

        Args:
            model: The base model to edit.
            config: Configuration object; ``config.model.inner_params`` lists the
                names of the parameters the editor may update.
            model_constructor: Zero-argument callable returning a fresh model
                (used by ``edit`` when ``detach_history`` is set).
            editor: Optional pre-built editor to reuse instead of creating one.
        """
        super().__init__(model, config, model_constructor)

        if editor is None:
            # Locate the token-embedding matrix; it initializes the editor's own embedding.
            if isinstance(model, BertClassifier):
                embedding = model.model.embeddings.word_embeddings.weight.data
            elif isinstance(model, BartForConditionalGeneration):
                embedding = model.model.shared.weight.data
            elif isinstance(model, T5ForConditionalGeneration):
                embedding = model.shared.weight.data
            else:
                # NOTE(review): assumes a GPT-2 style layout (``transformer.wte``) — confirm
                # for any other architecture passed in.
                embedding = model.transformer.wte.weight.data

            editor = OneShotLearner(model, vocab_dim=model.config.vocab_size,
                                    include_set=config.model.inner_params,
                                    embedding_dim=embedding.shape[-1],
                                    embedding_init=embedding.clone().to(torch.float32),
                                    max_scale=1)
        self.editor = editor

    def outer_parameters(self, grouped=False):
        """Return the editor's parameters.

        Args:
            grouped: When true, return a single optimizer-style param group using
                ``self.config.lr``; otherwise return a flat list of parameters.
        """
        if grouped:
            return [
                dict(params=self.editor.parameters(), lr=self.config.lr)
            ]
        else:
            return list(self.editor.parameters())

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the state dict without the base model's weights.

        Every entry belonging to ``self.model`` is removed and the model's config
        is stored under the (unprefixed) key ``"model_config"``, which is the key
        ``load_state_dict`` reads back.
        """
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        # Query the model's keys WITHOUT the prefix and re-attach the prefix in front of
        # "model." ourselves; passing the prefix to the model would produce keys of the
        # form "model.<prefix>..." that do not exist in the full state dict.
        for k in self.model.state_dict(keep_vars=keep_vars).keys():
            del state_dict[f"{prefix}model.{k}"]
        state_dict["model_config"] = self.model.config  # Include model config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load a state dict produced by ``state_dict``.

        The base model's weights are expected to be absent, so loading is always
        done non-strictly (``strict`` is accepted for interface compatibility but
        not forwarded). The caller's ``state_dict`` is not modified.

        Raises:
            KeyError: If ``"model_config"`` is missing from ``state_dict``.
            AssertionError: If non-model keys are missing or any key is unexpected.
        """
        # Shallow-copy so removing "model_config" does not mutate the caller's dict
        # (copy.copy keeps an OrderedDict's ``_metadata`` attribute, unlike ``dict(...)``).
        state_dict = copy.copy(state_dict)
        config = state_dict.pop("model_config")
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info(f"Loaded: {config}")
            LOG.info(f"Current: {self.model.config}")

        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def edit(self, batch, condition, detach_history=False):
        """Apply one editor-predicted update to the inner parameters.

        Args:
            batch: Keyword arguments for the model; must contain ``"labels"``, and
                also ``"input_ids"`` / ``"attention_mask"`` when ``condition`` is None.
            condition: Optional dict with ``"input_ids"`` and ``"attention_mask"``
                fed to the editor instead of the batch's own.
            detach_history: When true, copy the edited weights into a fresh model
                from ``self.model_constructor``.

        Returns:
            A tuple ``(edited KE instance sharing this editor, {})``.
        """
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        names = {n for n, _ in self.model.named_parameters()}
        for p in set(self.config.model.inner_params):
            assert p in names, f"inner param {p} not in model"

        grads = torch.autograd.grad(
            loss,
            [p for (n, p) in _inner_params(self.model.named_parameters(), self.config.model.inner_params)]
        )

        # NOTE(review): the zip below presumes _inner_params yields parameters in the same
        # order as config.model.inner_params — verify against utils._inner_params.
        params_dict = self.editor(
            condition["input_ids"] if condition is not None else batch["input_ids"],
            condition["attention_mask"] if condition is not None else batch["attention_mask"],
            {n: g.to(torch.float32) for (n, g) in zip(self.config.model.inner_params, grads)},
        )

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            edited_model = make_functional(edited_model, in_place=True)

        def new_param(n, p):
            """Return ``p`` plus its predicted update (transposed if the leading dims differ)."""
            if n not in params_dict:
                return p

            if p.shape[0] == params_dict[n].shape[0]:
                return p + params_dict[n]
            else:
                return p + params_dict[n].T

        edited_model.update_params(
            [new_param(n, p) for (n, p) in edited_model.named_parameters()]
        )

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return KE(edited_model, self.config, self.model_constructor, editor=self.editor), {}
class ConditionedParameter(torch.nn.Module):
    """Predict an update for one parameter from a condition vector and its gradient.

    A small weight-normalized MLP maps the condition vector to a multiplicative
    term ``a``, an additive term ``b`` and a scalar gate. The update is
    ``max_scale * sigmoid(gate) * (grad * a + b)``. For 2-D parameters ``a`` and
    ``b`` are rank-1 matrices built from a softmaxed row vector and a column vector.
    """

    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        """Build the conditioner MLP for ``parameter``.

        Args:
            parameter: The 1-D or 2-D tensor whose update will be predicted; only
                its shape is used.
            condition_dim: Size of the condition vector.
            hidden_dim: Hidden width of the MLP.
            max_scale: Constant factor applied to the predicted update.

        Raises:
            RuntimeError: If ``parameter`` is neither 1-D nor 2-D.
        """
        super().__init__()
        self.parameter_shape = parameter.shape

        if len(self.parameter_shape) == 2:
            # Two (row, column) vector pairs plus one scalar gate.
            out_dim = 2 * (parameter.shape[0] + parameter.shape[1]) + 1
        elif len(self.parameter_shape) == 1:
            # Two vectors (a, b) plus one scalar gate.
            out_dim = 2 * parameter.shape[0] + 1
        else:
            raise RuntimeError(
                "Only 1-D and 2-D parameters are supported, got shape {}".format(
                    tuple(self.parameter_shape)
                )
            )

        self.conditioners = torch.nn.Sequential(
            torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
            torch.nn.Tanh(),
            torch.nn.utils.weight_norm(torch.nn.Linear(hidden_dim, out_dim)),
        )

        self.max_scale = max_scale

    def forward(self, inputs, grad):
        """Return the predicted update for the parameter.

        Args:
            inputs: Condition tensor whose first dimension (batch) must not exceed 1.
            grad: Gradient tensor for the parameter; when its leading dimension
                differs from the parameter's, the predicted terms are transposed
                to match.

        Raises:
            RuntimeError: If ``inputs`` has a batch size greater than 1.
        """
        if inputs.shape[0] > 1:
            raise RuntimeError("Can only condition on batches of size 1")

        if len(self.parameter_shape) == 2:
            (
                conditioner_cola,
                conditioner_rowa,
                conditioner_colb,
                conditioner_rowb,
                conditioner_norm,
            ) = self.conditioners(inputs).split(
                [
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    1,
                ],
                dim=-1,
            )

            # (rows, 1) @ (1, cols) -> (rows, cols): already parameter-shaped, so no
            # squeeze is needed (a blanket squeeze would drop genuine size-1 dims).
            a = conditioner_rowa.softmax(-1).t() @ conditioner_cola
            b = conditioner_rowb.softmax(-1).t() @ conditioner_colb

        elif len(self.parameter_shape) == 1:
            a, b, conditioner_norm = self.conditioners(inputs).split(
                [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
            )
            # Drop only the batch dimension so a length-1 parameter stays 1-D.
            a = a.squeeze(0)
            b = b.squeeze(0)
        else:
            raise RuntimeError(
                "Only 1-D and 2-D parameters are supported, got shape {}".format(
                    tuple(self.parameter_shape)
                )
            )

        if a.shape[0] != grad.shape[0]:
            # Gradient is laid out transposed relative to the parameter.
            a, b = a.t(), b.t()

        return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a + b)
class LSTMConditioner(torch.nn.Module):
    """Encode a token sequence into a single condition vector.

    Pipeline: embedding lookup -> single-layer bidirectional LSTM wrapped as a
    seq2vec encoder -> one feed-forward layer with a tanh activation.
    """

    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        """Create the embedding, LSTM encoder and output projection.

        Args:
            vocab_dim: Number of rows in the embedding table.
            embedding_dim: Width of each embedding vector.
            hidden_dim: Hidden size of the LSTM (per direction).
            output_dim: Size of the produced condition vector.
            embedding_init: Optional tensor used as the initial embedding weights.
        """
        super().__init__()

        self.embedding = torch.nn.Embedding(
            vocab_dim,
            embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )

        recurrent = torch.nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm = PytorchSeq2VecWrapper(recurrent)

        # The bidirectional encoder emits forward and backward states concatenated.
        encoder_width = hidden_dim * 2
        self.linear = FeedForward(
            input_dim=encoder_width,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        """Return the condition vector for token ids ``inputs`` under ``masks``."""
        embedded = self.embedding(inputs)
        encoded = self.lstm(embedded, masks)
        return self.linear(encoded)
class OneShotLearner(torch.nn.Module):
    """Editor network: one ``ConditionedParameter`` per selected model parameter.

    All per-parameter conditioners share a single ``LSTMConditioner`` that turns
    the input tokens into the condition vector.
    """

    def __init__(
        self,
        model,
        vocab_dim,
        embedding_dim=768,
        hidden_dim=512,
        condition_dim=768,
        include_set=None,
        max_scale=1e-3,
        embedding_init=None,
    ):
        """Build the conditioners for the parameters of ``model`` named in ``include_set``.

        Args:
            model: Module whose ``named_parameters()`` are scanned.
            vocab_dim: Vocabulary size for the conditioner's embedding table.
            embedding_dim: Embedding width for the conditioner.
            hidden_dim: Hidden size used by both the LSTM conditioner and the
                per-parameter MLPs.
            condition_dim: Size of the condition vector.
            include_set: Iterable of parameter names to edit; ``None`` (the
                default) selects no parameters.
            max_scale: Constant factor applied to every predicted update.
            embedding_init: Optional initial weights for the conditioner's embedding.
        """
        super().__init__()

        # Build the membership set once; avoids a mutable default argument and a
        # linear scan per parameter when a list of names is passed.
        wanted = frozenset(include_set) if include_set is not None else frozenset()
        selected = [(n, p) for n, p in model.named_parameters() if n in wanted]

        # ModuleDict keys may not contain ".", so map each parameter name to a safe key.
        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_") for n, _ in selected
        }

        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in selected
            }
        )

        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        """Return ``{parameter name: predicted update}`` for every selected parameter.

        Args:
            inputs: Token ids passed to the shared conditioner.
            masks: Attention mask passed to the shared conditioner.
            grads: Dict mapping parameter names to gradients. If it is ``None`` or
                empty, ``None`` is forwarded as the gradient of each conditioner.
        """
        condition = self.condition(inputs, masks)
        return {
            name: self.conditioners[key](
                condition,
                grad=grads[name] if grads else None,
            )
            for name, key in self.param2conditioner_map.items()
        }
if __name__ == '__main__':
    # Manual smoke test: edit GPT-2 on a dummy batch and inspect the effect.
    # Requires a CUDA device and the pretrained "gpt2" weights.
    import transformers
    import types

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    # A bare SimpleNamespace has no ``model`` attribute, so create the nested
    # namespace explicitly before assigning ``inner_params`` to it.
    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace()
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]

    efk = KE(model, config, lambda: copy.deepcopy(model)).cuda()

    x = torch.arange(20).view(1, 20).cuda() + 1000
    # KE.edit takes a batch dict plus a condition and returns (edited_model, info).
    batch = dict(input_ids=x, attention_mask=torch.ones_like(x), labels=x)

    # NOTE(review): the ``.logits`` / ``.loss`` accesses below presume the wrapper's
    # forward returns the raw model output — confirm against EditableModel.forward.
    orig_logits = efk(x).logits
    edited, _ = efk.edit(batch, condition=None)
    post_logits = efk(x).logits

    assert torch.allclose(orig_logits, post_logits)

    orig_param = [p for (n, p) in efk.model.named_parameters() if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]

    print((orig_param - edited_param).abs().max())
    edited.eval()
    # The ["nll"] lookup belongs to the loss dict, not to print()'s return value.
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
    edited2, _ = edited.edit(batch, condition=None)
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)
    # Drop into the debugger for interactive inspection of the edited models.
    import pdb; pdb.set_trace()