"""Model editor that applies a few differentiable SGD steps to selected parameters.

The per-parameter step sizes (``edit_lrs``) are themselves an ``nn.Parameter`` and are
exposed through ``ENN.outer_parameters`` alongside the wrapped model's weights.
"""

import torch
import torch.nn as nn
import higher

from editable_model import EditableModel
from utils import _logits


def fomaml_callback(all_grads):
    """Return ``all_grads`` with every gradient detached; ``None`` entries pass through.

    Installed as the inner-loop ``grad_callback`` when ``config.enn.first_order`` is
    truthy, so the inner-step gradients carry no autograd history.
    """
    return [g.detach() if g is not None else None for g in all_grads]


def _identity_callback(all_grads):
    """Return ``all_grads`` unchanged (gradient callback for the non-first-order case)."""
    return all_grads


class ENN(EditableModel):
    """Editable wrapper whose ``edit`` runs SGD steps with learnable learning rates."""

    def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
        """Wrap ``model`` and set up the edit learning rates.

        Args:
            model: The network to be edited.
            config: Configuration object. This class reads ``edit_lr``,
                ``model.inner_params``, ``enn.first_order``, ``enn.n_edit_steps``,
                ``no_grad_layers``, ``lr`` and ``lr_lr`` from it.
            model_constructor: Zero-argument callable returning a fresh model; used by
                ``edit`` when ``detach_history`` is true.
            edit_lrs: Optional existing learning rates to reuse. When ``None``, a new
                ``nn.Parameter`` is created with one entry per name in
                ``config.model.inner_params``, each initialised to ``config.edit_lr``.
            edit_loss_fn: Optional replacement for the loss function set up by the
                base class.
        """
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            n_inner = len(self.config.model.inner_params)
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * n_inner))
        self.edit_lrs = edit_lrs

        # Only override the loss when one is given, so whatever the base class
        # installed stays in place otherwise.
        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        # A named function rather than a lambda keeps this attribute picklable.
        self.grad_callback = fomaml_callback if config.enn.first_order else _identity_callback

    def outer_parameters(self, grouped=False):
        """Return the parameters to optimise outside of ``edit``.

        When ``config.no_grad_layers`` is ``None`` all model parameters are included.
        Otherwise, for every ``nn.ModuleList`` found in the model, only the parameters
        of its entries from index ``config.no_grad_layers`` onward are included.

        Args:
            grouped: If true, return two optimizer parameter-group dicts (model
                parameters with ``lr=config.lr`` and the edit learning rates with
                ``lr=config.lr_lr``); otherwise return one flat list.
        """
        extra_params = [self.edit_lrs]

        if self.config.no_grad_layers is None:
            params = self.model.parameters()
            model_params = params if isinstance(params, list) else list(params)
        else:
            model_params = []
            for module in self.model.modules():
                if isinstance(module, nn.ModuleList):
                    model_params.extend(module[self.config.no_grad_layers:].parameters())

        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr),
            ]
        return model_params + extra_params

    def get_state_dict(self):
        """Return this module's ``state_dict()``."""
        return self.state_dict()

    def edit(self, batch, condition=None, detach_history=False):
        """Apply ``config.enn.n_edit_steps`` SGD steps on ``batch`` and wrap the result.

        Args:
            batch: Mapping of keyword arguments passed to the model as
                ``model(**batch)``; must contain a ``"labels"`` entry, which is also
                given to the edit loss.
            condition: Unused by this implementation.
            detach_history: If true, the edited weights are copied into a fresh model
                from ``model_constructor`` via ``load_state_dict``; otherwise the
                patched functional model is wrapped directly.

        Returns:
            A tuple ``(edited, info)`` where ``edited`` is a new ``ENN`` sharing this
            instance's ``edit_lrs`` and ``edit_loss_fn``, and ``info`` is an empty dict.
        """
        inner_names = set(self.config.model.inner_params)
        # One param group per inner parameter; the placeholder lr is replaced per group
        # by the learnable edit_lrs through ``override`` below.
        opt = torch.optim.SGD(
            [
                {"params": p, "lr": None}
                for (n, p) in self.model.named_parameters()
                if n in inner_names
            ]
        )
        # NOTE(review): ``in_place`` is passed through exactly as before; confirm the
        # installed ``higher`` build accepts this keyword.
        with torch.enable_grad(), higher.innerloop_ctx(
            self.model,
            opt,
            override={"lr": list(self.edit_lrs)},
            copy_initial_weights=False,
            track_higher_grads=self.training,
            in_place=True,
        ) as (fmodel, diffopt):
            fmodel.eval()
            for _ in range(self.config.enn.n_edit_steps):
                output = _logits(fmodel(**batch))
                loss = self.edit_loss_fn(output, batch["labels"])["nll"]
                diffopt.step(loss, grad_callback=self.grad_callback)

        if not detach_history:
            model_edited = fmodel
        else:
            model_edited = self.model_constructor()
            model_edited.load_state_dict(fmodel.state_dict())
        # The functional model was switched to eval mode above; put the edited model
        # back into the same train/eval mode as this wrapper.
        model_edited.train(self.training)

        edited = ENN(
            model_edited,
            self.config,
            self.model_constructor,
            edit_lrs=self.edit_lrs,
            edit_loss_fn=self.edit_loss_fn,
        )
        return edited, {}


def test():
    """Manual smoke test: edit GPT-2 on a toy batch and print pre/post-edit values.

    Needs ``transformers``, the pretrained ``gpt2`` weights and a CUDA device, and
    drops into ``pdb`` at the end for interactive inspection.
    """
    import transformers
    import types
    import copy

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    # ENN reads nested settings by attribute (config.model.inner_params,
    # config.enn.first_order, ...), so the nested sections must be namespaces too.
    # NOTE(review): the EditableModel base class is not visible here and may read
    # further config fields — confirm before relying on this test.
    config = types.SimpleNamespace()
    config.edit_lr = 0.1
    config.model = types.SimpleNamespace(
        inner_params=[
            "transformer.h.9.mlp.c_fc.weight",
            "transformer.h.9.mlp.c_proj.weight",
            "transformer.h.10.mlp.c_fc.weight",
            "transformer.h.10.mlp.c_proj.weight",
            "transformer.h.11.mlp.c_fc.weight",
            "transformer.h.11.mlp.c_proj.weight",
        ]
    )
    config.enn = types.SimpleNamespace(n_edit_steps=2, first_order=False)

    enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda()

    x = torch.arange(100).view(5, 20).cuda() + 1000

    # ``edit`` takes one batch mapping that is splatted into the model call, and it
    # returns an ``(edited_model, info)`` tuple.
    batch = {"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x}
    edited, _ = enn.edit(batch)

    target_name = config.model.inner_params[-1]
    orig_param = [p for (n, p) in enn.model.named_parameters() if n == target_name][0]
    edited_param = [p for (n, p) in edited.model.named_parameters() if n == target_name][0]

    print((orig_param - edited_param).abs().max())

    edited.eval()
    # NOTE(review): ``.loss`` / ``.logits`` assume the wrapper's forward returns the
    # underlying model's output object; forward is defined in EditableModel — confirm.
    print(
        enn(x, labels=x).loss,
        edited(x, labels=x).loss,
        edited.edit_loss_fn(edited(x).logits, x)["nll"],
    )
    # The loss function returns a mapping; backpropagate through its "nll" entry.
    edited.edit_loss_fn(edited(x).logits, x)["nll"].backward()
    import pdb; pdb.set_trace()


if __name__ == '__main__':
    with torch.autograd.set_detect_anomaly(True):
        test()