import torch
import torch.nn as nn
import higher

from editable_model import EditableModel
from utils import _logits
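

# Detaching the inner-loop gradients drops the second-order
# (gradient-of-gradient) terms, yielding the first-order MAML (FOMAML)
# approximation: cheaper in memory and compute, at some cost in
# meta-gradient fidelity.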
def fomaml_callback(all_grads):
    return [g.detach() if g is not None else None for g in all_grads]
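

# A meta-learned editor in the style of Editable Neural Networks
# (Sinitsin et al., 2020): an edit is a few SGD steps on the edit example
# with learned step sizes, unrolled differentiably via `higher` so the
# editing procedure itself can be trained.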
class ENN(EditableModel):
    def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
        super().__init__(model, config, model_constructor)

        if edit_lrs is None:
            # One learned edit step size per editable parameter tensor,
            # initialized to the scalar config.edit_lr.
            edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
        self.edit_lrs = edit_lrs

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x

    def outer_parameters(self, grouped=False):
        extra_params = [self.edit_lrs]
        if self.config.no_grad_layers is None:
            model_params = list(self.model.parameters())
        else:
            # Freeze the first no_grad_layers layers; only later layers
            # receive outer-loop gradients.
            model_params = []
            for m in self.model.modules():
                if isinstance(m, nn.ModuleList):
                    model_params.extend(list(m[self.config.no_grad_layers:].parameters()))

        if grouped:
            return [
                dict(params=model_params, lr=self.config.lr),
                dict(params=extra_params, lr=self.config.lr_lr),
            ]
        else:
            return model_params + extra_params
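
    # A minimal sketch of the intended use, assuming config.lr and
    # config.lr_lr are set (hypothetical outer-loop code, not part of
    # this module):
    #
    #   opt = torch.optim.Adam(enn.outer_parameters(grouped=True))
    #
    # Grouping lets the learned edit LRs train at their own rate (lr_lr),
    # separate from the model weights (lr).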

    def get_state_dict(self):
        return self.state_dict()

    def edit(self, batch, condition=None, detach_history=False):
        # lr=None is a placeholder; higher's override below substitutes the
        # learned per-parameter edit learning rates (self.edit_lrs).
        opt = torch.optim.SGD([{"params": p, "lr": None}
                               for (n, p) in self.model.named_parameters()
                               if n in self.config.model.inner_params])
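        # higher.innerloop_ctx yields a functional copy of the model (fmodel)
        # and a differentiable optimizer (diffopt); override substitutes
        # self.edit_lrs for the placeholder LRs above. Note: in_place is not
        # an argument of stock higher's innerloop_ctx; this assumes the
        # patched fork this codebase depends on.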
        with torch.enable_grad(), higher.innerloop_ctx(
            self.model,
            opt,
            override={'lr': list(self.edit_lrs)},
            copy_initial_weights=False,
            track_higher_grads=self.training,
            in_place=True
        ) as (fmodel, diffopt):
            fmodel.eval()
            for edit_step in range(self.config.enn.n_edit_steps):
                output = _logits(fmodel(**batch))
                loss = self.edit_loss_fn(output, batch["labels"])["nll"]
                diffopt.step(loss, grad_callback=self.grad_callback)
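
        # With track_higher_grads=self.training, the unrolled edit steps stay
        # in the autograd graph at training time, so an outer loss on the
        # edited model can backprop into the pre-edit weights and edit_lrs.
        # detach_history instead copies the edited weights into a fresh model,
        # severing that graph.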
        if not detach_history:
            model_edited = fmodel
        else:
            model_edited = self.model_constructor()
            model_edited.load_state_dict(fmodel.state_dict())
        model_edited.train(self.training)

        return ENN(model_edited, self.config, self.model_constructor,
                   edit_lrs=self.edit_lrs, edit_loss_fn=self.edit_loss_fn), {}
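

# A hedged sketch of one meta-training step, assuming an `edit_batch` dict
# shaped like the batches edit() consumes (hypothetical driver code, not
# part of this module):
#
#   opt = torch.optim.Adam(enn.outer_parameters(grouped=True))
#   edited, _ = enn.edit(edit_batch)
#   outer_loss = edited.edit_loss_fn(
#       _logits(edited(**edit_batch)), edit_batch["labels"])["nll"]
#   outer_loss.backward()  # reaches the pre-edit weights and edit_lrs
#   opt.step(); opt.zero_grad()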


def test():
    import transformers
    import types
    import copy

    model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")

    config = types.SimpleNamespace()
    config.edit_lr = 0.1
    config.model = types.SimpleNamespace()
    config.model.inner_params = [
        "transformer.h.9.mlp.c_fc.weight",
        "transformer.h.9.mlp.c_proj.weight",
        "transformer.h.10.mlp.c_fc.weight",
        "transformer.h.10.mlp.c_proj.weight",
        "transformer.h.11.mlp.c_fc.weight",
        "transformer.h.11.mlp.c_proj.weight",
    ]
    # ENN reads these as attributes (config.enn.n_edit_steps), so a
    # namespace is needed rather than a plain dict.
    config.enn = types.SimpleNamespace(n_edit_steps=2, first_order=False)

    enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda()
    x = torch.arange(100).view(5, 20).cuda() + 1000
    # edit() expects a batch dict (it calls fmodel(**batch)) and returns
    # an (edited_model, info) pair.
    edited, _ = enn.edit({"input_ids": x, "attention_mask": torch.ones_like(x), "labels": x})

    orig_param = [p for (n, p) in enn.model.named_parameters()
                  if n == config.model.inner_params[-1]][0]
    edited_param = [p for (n, p) in edited.model.named_parameters()
                    if n == config.model.inner_params[-1]][0]

    # The edited parameter should now differ from the original.
    print((orig_param - edited_param).abs().max())

    edited.eval()
    print(enn(x, labels=x).loss, edited(x, labels=x).loss,
          edited.edit_loss_fn(edited(x).logits, x)["nll"])
    # edit_loss_fn returns a dict of losses; backprop through the "nll" term.
    edited.edit_loss_fn(edited(x).logits, x)["nll"].backward()
    import pdb; pdb.set_trace()


if __name__ == '__main__':
    with torch.autograd.set_detect_anomaly(True):
        test()