# Charles Lin
# Add algorithms from efk codebase
# e56055d
import torch
import torch.nn as nn
import higher
from editable_model import EditableModel
from utils import _logits
def fomaml_callback(all_grads):
    """Gradient callback for first-order (FOMAML-style) inner-loop updates.

    Every gradient tensor is detached from the autograd graph, so no
    higher-order gradient flows back through it. ``None`` entries (parameters
    that received no gradient) are passed through unchanged.

    Args:
        all_grads: Iterable of gradient tensors and/or ``None``.

    Returns:
        A new list of the same length, with each tensor detached.
    """
    detached = []
    for grad in all_grads:
        detached.append(None if grad is None else grad.detach())
    return detached
class ENN(EditableModel):
    """Editable model that applies differentiable SGD steps to selected parameters.

    Edits run inside ``higher.innerloop_ctx`` over the parameters whose names
    appear in ``config.model.inner_params``. Each of those parameters gets its
    own learnable learning rate, stored together in ``self.edit_lrs``.

    NOTE: learning rates are matched to parameters positionally, in
    ``self.model.named_parameters()`` order. That order need not equal the
    order of ``config.model.inner_params``.
    """

    def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
        """Wrap ``model`` for editing.

        Args:
            model: Module to edit; passed to ``EditableModel.__init__``. The
                methods below access it as ``self.model``.
            config: Config object. This class reads ``edit_lr``,
                ``model.inner_params``, ``enn.first_order``, ``enn.n_edit_steps``,
                ``no_grad_layers``, ``lr`` and ``lr_lr``.
            model_constructor: Zero-argument callable returning a fresh model.
                Used by ``edit(detach_history=True)``.
            edit_lrs: Optional per-parameter learning-rate tensor to share. When
                omitted, a new ``nn.Parameter`` is created with one entry per
                name in ``config.model.inner_params``, each set to
                ``config.edit_lr``.
            edit_loss_fn: Optional loss override. When ``None``, the attribute is
                left as ``EditableModel`` set it (not visible in this file).
        """
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            n_inner = len(self.config.model.inner_params)
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * n_inner))
        self.edit_lrs = edit_lrs

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        # First-order mode detaches inner-loop gradients, so no second-order
        # terms are backpropagated. Otherwise gradients pass through unchanged.
        self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x

    def outer_parameters(self, grouped=False):
        """Return the parameters optimized by the outer (meta) optimizer.

        Model parameters are all of ``self.model.parameters()`` when
        ``config.no_grad_layers`` is ``None``. Otherwise they are only those of
        entries ``[no_grad_layers:]`` of every ``nn.ModuleList`` found in the
        module tree. ``self.edit_lrs`` is always included.

        Args:
            grouped: If True, return two optimizer param-group dicts instead of
                a flat list. Model params use ``config.lr``; ``edit_lrs`` uses
                ``config.lr_lr``.

        Returns:
            A flat list of parameters, or a list of two param-group dicts.
        """
        extra_params = [self.edit_lrs]
        if self.config.no_grad_layers is None:
            # nn.Module.parameters() yields a generator; materialise it once so
            # it can be concatenated with extra_params below.
            model_params = list(self.model.parameters())
        else:
            model_params = []
            for module in self.model.modules():
                if isinstance(module, nn.ModuleList):
                    model_params.extend(module[self.config.no_grad_layers:].parameters())

        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr),
            ]
        return model_params + extra_params

    def get_state_dict(self):
        """Return this module's full ``state_dict()`` (including ``edit_lrs``)."""
        return self.state_dict()

    def edit(self, batch, condition=None, detach_history=False):
        """Apply ``config.enn.n_edit_steps`` differentiable SGD steps using ``batch``.

        Args:
            batch: Mapping of keyword arguments for the wrapped model. It must
                contain ``"labels"``, which is also passed to ``edit_loss_fn``.
            condition: Unused by this implementation; kept for interface
                compatibility.
            detach_history: If False, the returned wrapper holds the
                ``higher``-patched model itself. If True, the edited weights are
                copied into a fresh model built by ``model_constructor``.

        Returns:
            Tuple ``(edited, info)``. ``edited`` is a new ``ENN`` that shares this
            instance's ``config``, ``model_constructor``, ``edit_lrs`` and
            ``edit_loss_fn``. ``info`` is an empty dict.
        """
        # Hoisted to a set: membership is tested once per model parameter.
        inner_names = set(self.config.model.inner_params)
        # One param group per inner parameter, so each can receive its own
        # learning rate. The "lr": None placeholder is replaced through the
        # `override` argument below.
        opt = torch.optim.SGD([{"params": p, "lr": None}
                               for (n, p) in self.model.named_parameters() if n in inner_names])
        # copy_initial_weights=False: the patched model starts from
        # self.model's own parameters, so outer gradients can reach them.
        # track_higher_grads=self.training: only track gradients through the
        # inner-loop updates while in training mode.
        with torch.enable_grad(), higher.innerloop_ctx(
            self.model,
            opt,
            override={'lr': list(self.edit_lrs)},
            copy_initial_weights=False,
            track_higher_grads=self.training,
            in_place=True,
        ) as (fmodel, diffopt):
            # Eval mode during the edit steps (e.g. disables dropout).
            fmodel.eval()
            for _ in range(self.config.enn.n_edit_steps):
                output = _logits(fmodel(**batch))
                loss = self.edit_loss_fn(output, batch["labels"])["nll"]
                diffopt.step(loss, grad_callback=self.grad_callback)

        if not detach_history:
            model_edited = fmodel
        else:
            model_edited = self.model_constructor()
            model_edited.load_state_dict(fmodel.state_dict())
        # fmodel was switched to eval mode above; give the returned model this
        # wrapper's own train/eval mode.
        model_edited.train(self.training)

        edited = ENN(
            model_edited,
            self.config,
            self.model_constructor,
            edit_lrs=self.edit_lrs,
            edit_loss_fn=self.edit_loss_fn,
        )
        return edited, {}
def test():
    """Manual smoke test: edit GPT-2's last three MLP blocks and report the effect.

    Loads the pretrained ``gpt2`` checkpoint through ``transformers``. Runs on
    CUDA when available and falls back to CPU otherwise. Prints:
    - the largest change in one edited weight;
    - losses before and after the edit.
    It then backpropagates the post-edit loss. Meant to be run as a script.
    """
    import copy
    import types

    import transformers

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    # ENN reads nested attributes such as config.model.inner_params and
    # config.enn.first_order / config.enn.n_edit_steps. The sub-configs must
    # therefore be attribute namespaces, not dicts.
    config = types.SimpleNamespace()
    config.edit_lr = 0.1
    config.model = types.SimpleNamespace()
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    config.enn = types.SimpleNamespace(n_edit_steps=2, first_order=False)

    enn = ENN(model, config, lambda: copy.deepcopy(model)).to(device)
    x = torch.arange(100).view(5, 20).to(device) + 1000

    # edit() takes one mapping of model kwargs (including "labels") and
    # returns an (edited_model, info) tuple.
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}
    edited, _ = enn.edit(batch)

    target = config.model.inner_params[-1]
    orig_param = dict(enn.model.named_parameters())[target]
    edited_param = dict(edited.model.named_parameters())[target]
    print((orig_param - edited_param).abs().max())

    edited.eval()
    # Call the wrapped HF models directly. `.loss` / `.logits` are attributes of
    # their output objects, and the wrapper's forward is defined outside this file.
    print(
        enn.model(x, labels=x).loss,
        edited.model(x, labels=x).loss,
        edited.edit_loss_fn(edited.model(x).logits, x)["nll"],
    )
    # edit_loss_fn returns a dict of losses (edit() reads its "nll" entry), so
    # backpropagate through that scalar rather than through the dict.
    edited.edit_loss_fn(edited.model(x).logits, x)["nll"].backward()
    # Show whether (and how much) gradient reached the learnable edit lrs.
    print(enn.edit_lrs.grad)
if __name__ == '__main__':
    # Run the smoke test with autograd anomaly detection enabled, so a failing
    # backward pass reports the forward operation that produced it.
    anomaly_ctx = torch.autograd.set_detect_anomaly(True)
    with anomaly_ctx:
        test()