Spaces:
Runtime error
Runtime error
File size: 1,117 Bytes
e56055d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch.nn as nn
from losses import masked_log_probs
from utils import _logits, shift_targets
class EditableModel(nn.Module):
    """Base wrapper around a model that subclasses extend with an edit step.

    The wrapper stores the wrapped ``model``, a ``config`` object and a
    ``model_constructor``, and installs a shared loss callable as both
    ``edit_loss_fn`` and ``loc_loss_fn``. Subclasses are expected to override
    :meth:`edit`; the base implementation raises ``NotImplementedError``.
    """

    def __init__(self, model, config, model_constructor):
        """Store the wrapped model and build the loss callables.

        Args:
            model: The wrapped module; it is called in :meth:`forward` and
                its ``generate`` method is used by :meth:`generate`.
            config: Configuration object. ``config.lr`` is read by
                :meth:`outer_parameters`, and the whole object is passed to
                ``shift_targets`` each time a loss is computed.
            model_constructor: Stored as-is on the instance; not used by
                this base class.
        """
        super().__init__()
        self.model = model
        self.config = config
        self.model_constructor = model_constructor

        def _edit_loss_fn(pred, targ, **kwargs):
            # ``shift_targets`` is evaluated on every call, so later changes
            # to ``self.config`` are picked up by the loss.
            return masked_log_probs(
                pred, targ, shift=shift_targets(self.config), **kwargs
            )

        # Both losses intentionally share the very same callable.
        self.edit_loss_fn = _edit_loss_fn
        self.loc_loss_fn = _edit_loss_fn

    def edit(self, batch, condition=None, detach_history=False):
        """Apply an edit; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def forward(self, *inputs, **kwargs):
        """Call the wrapped model and pass its output through ``_logits``."""
        return _logits(self.model(*inputs, **kwargs))

    def outer_parameters(self, grouped=False):
        """Return this module's parameters for an outer optimizer.

        Args:
            grouped: When true, return a single-element list holding one
                parameter-group dict with keys ``params`` and ``lr``
                (``lr`` taken from ``self.config.lr``). Otherwise return a
                flat list of parameters.

        Returns:
            A list in both cases. The parameters are always materialised
            into a list so the result can be iterated more than once.
        """
        # Materialise once: ``self.parameters()`` is a one-shot generator, and
        # storing it directly in the group dict would leave ``params`` empty
        # for any consumer after the first to iterate it.
        params = list(self.parameters())
        if grouped:
            return [dict(params=params, lr=self.config.lr)]
        return params

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate`` method."""
        return self.model.generate(*args, **kwargs)

    def base_loss(self, input_ids, attention_masks, label_ids):
        """Placeholder hook; the base implementation does nothing.

        Returns:
            None.
        """
        pass
|