model-editing / editable_model.py
Charles Lin
Add algorithms from efk codebase
e56055d
import torch.nn as nn
from losses import masked_log_probs
from utils import _logits, shift_targets
class EditableModel(nn.Module):
def __init__(self, model, config, model_constructor):
super().__init__()
self.model = model
self.config = config
self.model_constructor = model_constructor
def _edit_loss_fn(pred, targ, **kwargs):
return masked_log_probs(pred, targ, shift=shift_targets(self.config), **kwargs)
self.edit_loss_fn = _edit_loss_fn
self.loc_loss_fn = _edit_loss_fn
def edit(self, batch, condition=None, detach_history=False):
raise NotImplementedError
def forward(self, *inputs, **kwargs):
return _logits(self.model(*inputs, **kwargs))
def outer_parameters(self, grouped=False):
if grouped:
return [dict(params=self.parameters(), lr=self.config.lr)]
else:
return list(self.parameters())
def generate(self, *args, **kwargs):
return self.model.generate(*args, **kwargs)
def base_loss(self, input_ids, attention_masks, label_ids):
pass