Spaces:
Runtime error
Runtime error
# Adapted from https://github.com/nicola-decao/KnowledgeEditor/blob/main/src/models/one_shot_learner.py
"""
@inproceedings{decao2020editing,
title={Editing Factual Knowledge in Language Models},
author={Nicola De Cao and Wilker Aziz and Ivan Titov},
booktitle={arXiv pre-print 2104.08164},
url={https://arxiv.org/abs/2104.08164},
year={2021},
}
"""
import torch | |
import copy | |
import higher | |
from higher.patch import monkeypatch as make_functional | |
from allennlp.modules.feedforward import FeedForward | |
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper | |
import logging | |
from editable_model import EditableModel | |
from utils import _logits, _inner_params | |
from models import BertClassifier | |
from transformers import BartForConditionalGeneration, T5ForConditionalGeneration | |
LOG = logging.getLogger(__name__) | |
class KE(EditableModel):
    """KnowledgeEditor-style editable model.

    Wraps ``model`` together with a hypernetwork ``editor`` (a
    ``OneShotLearner``). The editor maps the gradient of the edit loss with
    respect to ``config.model.inner_params`` to additive parameter updates,
    conditioned on the edit input.
    """

    def __init__(self, model, config, model_constructor, editor=None):
        """Build the wrapper and, unless ``editor`` is given, a fresh editor.

        Args:
            model: the network to edit.
            config: experiment config; ``config.model.inner_params`` selects the
                edited parameters and ``config.lr`` is used by
                ``outer_parameters(grouped=True)``.
            model_constructor: zero-argument callable returning a new model;
                used by ``edit(..., detach_history=True)``.
            editor: an existing ``OneShotLearner`` to reuse (``edit`` passes its
                own editor to the returned model); built from ``model`` if None.
        """
        super().__init__(model, config, model_constructor)
        if editor is None:
            # The wrapped model's input-token embedding matrix initialises the
            # editor's own embedding table.
            if isinstance(model, BertClassifier):
                embedding = model.model.embeddings.word_embeddings.weight.data
            elif isinstance(model, BartForConditionalGeneration):
                embedding = model.model.shared.weight.data
            elif isinstance(model, T5ForConditionalGeneration):
                embedding = model.shared.weight.data
            else:
                # Fallback assumes a GPT-2-style ``transformer.wte`` embedding.
                embedding = model.transformer.wte.weight.data
            editor = OneShotLearner(
                model,
                vocab_dim=model.config.vocab_size,
                include_set=config.model.inner_params,
                embedding_dim=embedding.shape[-1],
                embedding_init=embedding.clone().to(torch.float32),
                max_scale=1,
            )
        self.editor = editor

    def outer_parameters(self, grouped=False):
        """Return the outer-loop (trainable) parameters: only the editor's.

        With ``grouped=True`` returns one optimizer param group using
        ``config.lr``; otherwise a flat list of tensors.
        """
        if grouped:
            return [dict(params=self.editor.parameters(), lr=self.config.lr)]
        return list(self.editor.parameters())

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Return the state dict without the wrapped model's weights.

        Only the editor's state is kept, plus the wrapped model's config under
        ``"model_config"`` so ``load_state_dict`` can report a mismatch.
        """
        # Forward ``destination`` so a parent module's state_dict gets filled.
        state_dict = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        # Collect the model's keys WITHOUT the prefix: in the combined dict they
        # appear as "<prefix>model.<key>".
        for key in self.model.state_dict(keep_vars=keep_vars):
            del state_dict[f"{prefix}model.{key}"]
        state_dict["model_config"] = self.model.config
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load editor weights produced by ``state_dict``.

        The checkpoint contains no wrapped-model weights, so loading is always
        non-strict (``strict`` is accepted for API compatibility but ignored).
        Only wrapped-model keys may be missing, and no keys may be unexpected.

        Returns:
            The ``(missing_keys, unexpected_keys)`` result of the parent loader.
        """
        # Shallow copy (preserves any ``_metadata`` attribute) so removing
        # "model_config" does not mutate the caller's dict; the same dict can
        # then be loaded more than once.
        state_dict = copy.copy(state_dict)
        config = state_dict.pop("model_config")
        if config != self.model.config:
            LOG.info("Loaded model config doesn't match current model config.")
            LOG.info("Loaded: %s", config)
            LOG.info("Current: %s", self.model.config)
        res = super().load_state_dict(state_dict, False)
        # We should only have missing keys for the model, and no unexpected keys
        assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
        assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
        return res

    def edit(self, batch, condition, detach_history=False):
        """Apply one edit and return ``(edited KE, {})``.

        Args:
            batch: keyword inputs for ``self.model``; must contain ``"labels"``
                and, when ``condition`` is None, ``"input_ids"`` and
                ``"attention_mask"``.
            condition: optional dict with ``"input_ids"``/``"attention_mask"``
                used to condition the editor; ``batch`` is used when None.
            detach_history: if True, copy the edited weights into a fresh model
                from ``model_constructor`` instead of returning the patched one.
        """
        outputs = _logits(self.model(**batch))
        loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]

        inner_names = self.config.model.inner_params
        param_names = {n for n, _ in self.model.named_parameters()}
        for name in inner_names:
            assert name in param_names, f"inner param {name} not in model"

        inner = list(_inner_params(self.model.named_parameters(), inner_names))
        grads = torch.autograd.grad(loss, [p for _, p in inner])

        cond = condition if condition is not None else batch
        # Key each gradient by the name _inner_params reported for its tensor.
        params_dict = self.editor(
            cond["input_ids"],
            cond["attention_mask"],
            {n: g.to(torch.float32) for (n, _), g in zip(inner, grads)},
        )

        edited_model = self.model
        if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
            # Wrap with higher's monkeypatch so update_params can install the
            # edited tensors.
            edited_model = make_functional(edited_model, in_place=True)

        def new_param(name, param):
            """Add the editor's update; transpose it if its leading dim differs."""
            if name not in params_dict:
                return param
            delta = params_dict[name]
            if param.shape[0] != delta.shape[0]:
                delta = delta.T
            return param + delta

        edited_model.update_params(
            [new_param(n, p) for n, p in edited_model.named_parameters()]
        )

        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edited_model.state_dict())
            edited_model = new_model

        return KE(edited_model, self.config, self.model_constructor, editor=self.editor), {}
class ConditionedParameter(torch.nn.Module):
    """Hypernetwork head that produces an update for one parameter tensor.

    Given a condition vector (batch of exactly one) and the parameter's
    gradient, returns ``max_scale * sigmoid(norm) * (grad * a + b)``. Here
    ``a`` and ``b`` have the parameter's shape, transposed if ``grad``'s
    leading dimension does not match. For 2-D parameters, ``a`` and ``b`` are
    outer products of a softmax-normalised row vector and a column vector.
    """

    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        """
        Args:
            parameter: the tensor to be edited; only its shape is used.
            condition_dim: size of the conditioning vector fed to ``forward``.
            hidden_dim: width of the hidden layer of the head.
            max_scale: constant multiplier on the produced update.

        Raises:
            RuntimeError: if ``parameter`` is not 1-D or 2-D.
        """
        super().__init__()
        self.parameter_shape = parameter.shape
        if len(self.parameter_shape) == 2:
            rows, cols = self.parameter_shape
            # Two (column, row) vector pairs plus one scalar gate.
            out_dim = 2 * (rows + cols) + 1
        elif len(self.parameter_shape) == 1:
            # Two full-length vectors plus one scalar gate.
            out_dim = 2 * self.parameter_shape[0] + 1
        else:
            raise RuntimeError(
                "ConditionedParameter supports 1-D or 2-D parameters, "
                f"got shape {tuple(self.parameter_shape)}"
            )
        self.conditioners = torch.nn.Sequential(
            torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
            torch.nn.Tanh(),
            torch.nn.utils.weight_norm(torch.nn.Linear(hidden_dim, out_dim)),
        )
        self.max_scale = max_scale

    def forward(self, inputs, grad):
        """Return the update for ``grad`` conditioned on ``inputs``.

        Args:
            inputs: conditioning tensor of shape ``(1, condition_dim)``.
            grad: gradient with the parameter's shape (or, for 2-D, transposed).

        Raises:
            RuntimeError: if ``inputs`` holds more than one example.
        """
        if inputs.shape[0] > 1:
            raise RuntimeError("Can only condition on batches of size 1")
        raw = self.conditioners(inputs)
        if len(self.parameter_shape) == 2:
            rows, cols = self.parameter_shape
            col_a, row_a, col_b, row_b, norm = raw.split(
                [cols, rows, cols, rows, 1], dim=-1
            )
            # (rows, 1) @ (1, cols) -> (rows, cols)
            a = row_a.softmax(-1).T @ col_a
            b = row_b.softmax(-1).T @ col_b
        elif len(self.parameter_shape) == 1:
            size = self.parameter_shape[0]
            a, b, norm = raw.split([size, size, 1], dim=-1)
            # Drop only the batch dim; squeeze() would also collapse a size-1
            # parameter dimension and break the shape check below.
            a = a.reshape(self.parameter_shape)
            b = b.reshape(self.parameter_shape)
        else:
            raise RuntimeError(
                f"Unsupported parameter shape {tuple(self.parameter_shape)}"
            )
        scale = self.max_scale * norm.sigmoid().reshape(())
        if a.shape[0] != grad.shape[0]:
            # Defensive: the gradient is laid out transposed w.r.t. parameter_shape.
            a, b = a.T, b.T
        return scale * (grad * a + b)
class LSTMConditioner(torch.nn.Module):
    """Encode a token sequence into a single conditioning vector.

    The pipeline is: token embedding, then a single-layer bidirectional LSTM
    wrapped as an allennlp Seq2Vec encoder (fed ``masks``), then one
    feed-forward layer of width ``output_dim`` with Tanh activation.
    """

    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        """
        Args:
            vocab_dim: number of token ids in the embedding table.
            embedding_dim: embedding width (LSTM input size).
            hidden_dim: LSTM hidden size per direction.
            output_dim: size of the returned conditioning vector.
            embedding_init: optional initial embedding weight tensor.
        """
        super().__init__()
        # Index 0 is the padding index; ``embedding_init`` (if given) is used
        # as the initial weight matrix.
        self.embedding = torch.nn.Embedding(
            vocab_dim, embedding_dim, padding_idx=0, _weight=embedding_init
        )
        bilstm = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm = PytorchSeq2VecWrapper(bilstm)
        # A bidirectional LSTM emits 2 * hidden_dim features.
        self.linear = FeedForward(
            input_dim=2 * hidden_dim,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        """Embed ``inputs``, encode them under ``masks``, and project the result."""
        embedded = self.embedding(inputs)
        encoded = self.lstm(embedded, masks)
        return self.linear(encoded)
class OneShotLearner(torch.nn.Module):
    """KnowledgeEditor hypernetwork.

    An ``LSTMConditioner`` encodes the conditioning tokens. For every selected
    parameter of ``model``, a ``ConditionedParameter`` head then turns that
    parameter's gradient into an update tensor.
    """

    def __init__(
        self,
        model,
        vocab_dim,
        embedding_dim=768,
        hidden_dim=512,
        condition_dim=768,
        include_set=None,
        max_scale=1e-3,
        embedding_init=None,
    ):
        """
        Args:
            model: network whose named parameters are scanned.
            vocab_dim: vocabulary size for the conditioner's embedding.
            embedding_dim: conditioner embedding width.
            hidden_dim: hidden width shared by the LSTM and the heads.
            condition_dim: size of the conditioning vector.
            include_set: collection of parameter names to edit; ``None``
                selects no parameters.
            max_scale: passed to every ``ConditionedParameter``.
            embedding_init: optional initial embedding weight tensor.
        """
        super().__init__()
        wanted = frozenset(include_set) if include_set is not None else frozenset()
        # Single pass in model order, so the map order and the order in which
        # heads are constructed are the same.
        selected = [(n, p) for n, p in model.named_parameters() if n in wanted]
        # torch module names cannot contain ".", so flatten dotted names.
        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_") for n, _ in selected
        }
        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in selected
            }
        )
        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        """Return ``{parameter_name: update}`` for every selected parameter.

        Args:
            inputs: token ids to condition on.
            masks: mask passed to the conditioner alongside ``inputs``.
            grads: mapping from parameter name to its gradient. If empty or
                None, ``None`` is passed as the gradient, which
                ``ConditionedParameter.forward`` cannot use (it reads
                ``grad.shape``), so callers should always supply it.
        """
        condition = self.condition(inputs, masks)
        return {
            name: self.conditioners[key](
                condition,
                grad=grads[name] if grads else None,
            )
            for name, key in self.param2conditioner_map.items()
        }
if __name__ == '__main__':
    # Manual smoke test: edit GPT-2 on a toy sequence and print loss changes.
    # NOTE(review): EditableModel may read further config fields (not visible
    # here); extend the namespace below if construction fails.
    import transformers
    import types

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
    # KE reads config.model.inner_params; a bare SimpleNamespace has no
    # ``model`` attribute, so build the nested namespace explicitly.
    config = types.SimpleNamespace(
        model=types.SimpleNamespace(
            inner_params=[
                "transformer.h.9.mlp.c_fc.weight",
                "transformer.h.9.mlp.c_proj.weight",
                "transformer.h.10.mlp.c_fc.weight",
                "transformer.h.10.mlp.c_proj.weight",
                "transformer.h.11.mlp.c_fc.weight",
                "transformer.h.11.mlp.c_proj.weight",
            ]
        )
    )
    efk = KE(model, config, lambda: copy.deepcopy(model)).to(device)

    x = torch.arange(20).view(1, 20).to(device) + 1000
    # KE.edit expects a model-input batch (with labels) plus an optional
    # condition dict, and returns (edited KE, info dict).
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}

    orig_logits = efk(x).logits
    edited, _ = efk.edit(batch, None)
    post_logits = efk(x).logits
    assert torch.allclose(orig_logits, post_logits)

    target = config.model.inner_params[-1]
    orig_param = next(p for n, p in efk.model.named_parameters() if n == target)
    edited_param = next(p for n, p in edited.model.named_parameters() if n == target)
    print((orig_param - edited_param).abs().max())

    edited.eval()
    print(
        efk(x, labels=x).loss,
        edited(x, labels=x).loss,
        edited.edit_loss_fn(edited(x).logits, x)["nll"],
    )

    edited2, _ = edited.edit(batch, None)
    print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)