import torch.nn as nn

from losses import masked_log_probs
from utils import _logits, shift_targets


class EditableModel(nn.Module):
    """Base wrapper for a model that supports editing; subclasses implement `edit`."""

    def __init__(self, model, config, model_constructor):
        super().__init__()

        self.model = model
        self.config = config
        self.model_constructor = model_constructor

        def _edit_loss_fn(pred, targ, **kwargs):
            # Masked log-probability loss; whether targets are shifted is
            # determined from the config via shift_targets.
            return masked_log_probs(pred, targ, shift=shift_targets(self.config), **kwargs)

        # The same loss is used for edit examples and locality examples.
        self.edit_loss_fn = _edit_loss_fn
        self.loc_loss_fn = _edit_loss_fn

    def edit(self, batch, condition=None, detach_history=False):
        # Subclasses implement the actual editing procedure.
        raise NotImplementedError

    def forward(self, *inputs, **kwargs):
        # Delegate to the wrapped model and normalize its output to logits.
        return _logits(self.model(*inputs, **kwargs))

    def outer_parameters(self, grouped=False):
        # Parameters trained in the outer loop, optionally grouped with the
        # configured learning rate for an optimizer's param groups.
        if grouped:
            return [dict(params=self.parameters(), lr=self.config.lr)]
        else:
            return list(self.parameters())

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

    def base_loss(self, input_ids, attention_masks, label_ids):
        pass
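

# Minimal usage sketch: how a concrete subclass of EditableModel might be
# driven. The subclass name, the config fields, and the return value of
# `edit` are assumptions here, since this module only defines the interface.
#
#   editable = SomeEditableSubclass(base_model, config, model_constructor)
#   logits = editable(input_ids, attention_mask=attention_mask)   # plain forward pass
#   edited = editable.edit(batch)                                 # apply an edit to the wrapped model
#   params = editable.outer_parameters(grouped=True)              # param groups for an outer-loop optimizer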