Spaces:
Runtime error
Runtime error
File size: 5,010 Bytes
e56055d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
import torch.nn as nn
import higher
from higher.patch import monkeypatch as make_functional
import time
from editable_model import EditableModel
from utils import _logits, _inner_params
from losses import kl_loc_loss
class FT(EditableModel):
    """
    Fine-tuning approach. Does not require training.

    An edit is performed by optimising additive residuals for the parameters
    selected by ``config.model.inner_params``. Each residual is either a full
    tensor shaped like the parameter or, when ``config.ft.rank`` is set, the
    product of two low-rank factors.
    """

    def __init__(self, model, config, model_constructor, edit_loss_fn=None):
        """
        Wrap ``model`` for fine-tuning based editing.

        Args:
            model: Model to edit; forwarded to ``EditableModel.__init__``.
            config: Configuration object; this class reads ``config.ft.*``,
                ``config.model.inner_params``, ``config.edit_lr`` and
                ``config.grad_clip``.
            model_constructor: Zero-argument callable returning a fresh model;
                called in ``edit`` when ``detach_history`` is true.
            edit_loss_fn: Optional replacement for the edit loss. When
                ``None``, ``self.edit_loss_fn`` is not assigned here
                (presumably the base class provides it -- confirm).
        """
        super().__init__(model, config, model_constructor)

        if edit_loss_fn is not None:
            self.edit_loss_fn = edit_loss_fn

        self.locality_loss_fn = kl_loc_loss
        self.loc_ids = None
        self.loc_masks = None
        # Must be set to an iterator by the caller before `edit` is used with
        # `config.ft.locality.enabled`; `_edit_loss` calls `next()` on it.
        self.loc_sampler = None

    def _edit_loss(self, model, p0, p_edited, edit_batch):
        """
        Compute the loss for one fine-tuning step.

        Args:
            model: Functional model accepting a ``params=`` keyword.
            p0: Original parameters (used for the locality reference logits).
            p_edited: Parameters with the current residuals applied.
            edit_batch: Keyword arguments for the model; must contain
                ``"labels"``.

        Returns:
            Tuple ``(loss, l_edit, l_loc, acc)``. ``l_loc`` is a NaN tensor
            when the locality term is disabled.

        Raises:
            NotImplementedError: Locality is enabled but
                ``config.ft.locality.oracle`` is false.
        """
        output = _logits(model(**edit_batch, params=p_edited))
        loss_dict = self.edit_loss_fn(output, edit_batch["labels"])
        l_edit, acc = loss_dict["nll"], loss_dict["acc"]

        if self.config.ft.locality.enabled:
            if self.config.ft.locality.oracle:
                loc_batch = next(self.loc_sampler)["loc"]
            else:
                raise NotImplementedError

            # The reference logits are a fixed target, so no graph is needed.
            with torch.no_grad():
                original_base_logits = _logits(model(**loc_batch, params=p0))
            # Gradients must flow through the edited logits into the residuals.
            edited_base_logits = _logits(model(**loc_batch, params=p_edited))

            kl_mask = loc_batch.get(
                "decoder_attention_mask", loc_batch["attention_mask"]
            )
            l_loc = self.locality_loss_fn(
                original_base_logits, edited_base_logits, mask=kl_mask
            )
            loss = l_loc + self.config.ft.locality.cedit * l_edit
        else:
            # Placeholder so callers can still call `.item()` on `l_loc`.
            l_loc = torch.tensor(float('nan'))
            loss = l_edit

        return loss, l_edit, l_loc, acc

    def accuracy(self, output, labels):
        """
        Score ``output`` against ``labels``.

        When the last dimension of ``output`` is not 1, the argmax over that
        dimension (last position dropped) is compared with ``labels`` shifted
        left by one, and the number of matches is divided by the number of
        labels that are not ``-100``.

        Otherwise the result is the *number* of positions where
        ``output > 0`` agrees with ``labels.bool()``.
        NOTE(review): this branch returns a count rather than a fraction,
        unlike the branch above -- confirm whether a mean was intended.
        """
        if output.shape[-1] != 1:
            shifted_output = output.argmax(-1)[:, :-1]
            shifted_labels = labels[:, 1:]
            to_predict = (shifted_labels != -100).sum()
            correct = (shifted_output == shifted_labels).sum()
            acc = correct.float() / to_predict.float()
        else:
            acc = ((output > 0) == labels.bool()).sum().float()
        return acc

    def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p):
        """Format a fixed-width, single-line progress report for one step."""
        return (
            f"step: {step}".ljust(14) +
            f"loss: {loss.item():.5f}".ljust(18) +
            f"l_edit: {l_edit.item():.5f}".ljust(18) +
            f"l_loc: {l_loc.item():.5f}".ljust(18) +
            f"acc: {acc.item():.2f}".ljust(14) +
            f"norm: {res_p.view(-1).norm().item():.5f}"
        )

    def edit(self, batch, condition=None, detach_history=False):
        """
        Run the fine-tuning edit loop on ``batch``.

        Args:
            batch: Keyword arguments for the model, including ``"labels"``.
            condition: Unused by this method.
            detach_history: When true, the returned model is a fresh instance
                from ``self.model_constructor()`` loaded with the state dict
                of the working model.

        Returns:
            Tuple ``(FT, dict)``: a new ``FT`` wrapping the resulting model
            and an empty info dict.
        """
        edit_model = self.model.eval()
        p0 = list(edit_model.named_parameters())

        # A functional (monkey-patched) model is needed so that parameters
        # can be supplied through the `params=` keyword.
        if not isinstance(edit_model, higher.patch._MonkeyPatchBase):
            edit_model = make_functional(
                self.model, track_higher_grads=False, in_place=True
            )

        packed_residuals = {}
        opt_params = []
        selected = _inner_params(
            edit_model.named_parameters(), self.config.model.inner_params
        )
        for n, p in selected:
            if self.config.ft.rank is not None:
                # Low-rank residual u @ v; v starts at zero so the initial
                # residual is exactly zero.
                u = nn.Parameter(
                    torch.randn(p.shape[0], self.config.ft.rank, device=p.device)
                    * self.config.ft.init_std
                )
                v = nn.Parameter(
                    torch.zeros(self.config.ft.rank, p.shape[1], device=p.device)
                )
                res = [u, v]
            else:
                res = [nn.Parameter(torch.zeros_like(p, device=p.device))]

            packed_residuals[n] = res
            opt_params.extend(res)

        # Count residual *groups*, not tensors: in low-rank mode every
        # parameter contributes two tensors to `opt_params`, so comparing
        # `len(opt_params)` would always fail there.
        assert len(packed_residuals) == len(self.config.model.inner_params)

        OptClass = getattr(torch.optim, self.config.ft.opt)
        opt = OptClass(opt_params, lr=self.config.edit_lr)

        # The original parameter tensors do not change between steps.
        base_params = [p for n, p in p0]

        start_time = time.time()
        for edit_step in range(self.config.ft.max_edit_steps):
            if self.config.ft.time_limit is not None and (
                time.time() - start_time > self.config.ft.time_limit
            ):
                break

            residuals = {
                k: (v[0] @ v[1]) if len(v) == 2 else v[0]
                for k, v in packed_residuals.items()
            }
            # Detach the base weight so gradients reach only the residuals.
            edited_params = [
                p if n not in residuals else p.detach() + residuals[n]
                for n, p in p0
            ]
            loss, l_edit, l_loc, acc = self._edit_loss(
                edit_model, base_params, edited_params, batch
            )

            if self.config.ft.verbose:
                residual = list(residuals.values())[-1]
                print(
                    self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual),
                    end="\r",
                )

            # Stop as soon as the edit batch is predicted perfectly.
            if acc == 1.0:
                break

            for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)):
                p.grad = g
            torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip)
            opt.step()
            opt.zero_grad()

        # NOTE(review): the optimised residuals are only ever passed to
        # `_edit_loss`; nothing in this method writes them into `edit_model`
        # before it is returned -- confirm whether they should be applied.
        if detach_history:
            new_model = self.model_constructor()
            new_model.load_state_dict(edit_model.state_dict())
            edit_model = new_model
        edit_model.train(self.training)

        return FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), {}
|