import torch

from editable_model import EditableModel
from utils import _last_encoder_state, _logits

class LU(EditableModel):
    """
    Representation lookup approach. Does not require training.
    """

    def __init__(self, model, config, model_constructor, memory=None):
        super().__init__(model, config, model_constructor)

        if "t5" not in self.config.model.name.lower():
            raise NotImplementedError
        self.memory = memory

    def lookup_replace(self, output, encoder_states):
        # For each example in the batch, compare its mean-pooled encoder state
        # to the stored memory keys. If the nearest key is within the distance
        # threshold, overwrite the model's logits with a one-hot encoding of
        # that key's stored label sequence.
        memory_keys, memory_labels = self.memory
        for i, encoder_state in enumerate(encoder_states):
            avg_encoder_state = encoder_state.detach().mean(0)
            dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
            closest_dist = dists.min()
            closest_idx = dists.argmin()
            closest_v = memory_labels[closest_idx]

            if closest_dist < self.config.lu.threshold:
                # zeros_like preserves the shape, dtype, and device of the logits.
                output[i] = torch.zeros_like(output[i])
                for j, idx in enumerate(closest_v):
                    if j >= output.shape[1]:
                        break  # stored label is longer than the output sequence
                    output[i, j, idx] = self.config.lu.onehot_logit
                if "t5" not in self.config.model.name.lower():
                    # T5 does not shift targets in the loss; other seq2seq
                    # models do, so realign the one-hot logits. Unreachable
                    # given the T5-only check in __init__.
                    output[i] = output[i].roll(-1, -2)
        return output

    def generate(self, *inputs, **kwargs):
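        # Request scores and hidden states so the lookup can read the encoder
        # representation and overwrite the generated logits.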
        model_output = self.model.generate(*inputs, **kwargs, output_hidden_states=True,
                                           output_scores=True, return_dict_in_generate=True)
        encoder_states = _last_encoder_state(model_output)
        output = _logits(model_output)
        if self.memory is not None:
            output = self.lookup_replace(output, encoder_states)
        return output.argmax(-1)

    def forward(self, *inputs, **kwargs):
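        # Same lookup as generate(), but returns the full (possibly replaced)
        # logits so a loss can be computed against them.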
        model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
        encoder_states = _last_encoder_state(model_output)
        output = _logits(model_output)
        if self.memory is not None:
            output = self.lookup_replace(output, encoder_states)
        return output

    def edit(self, batch, condition=None, detach_history=False):
        # Editing requires no gradient steps: cache the mean-pooled encoder
        # state of each edit example as a key and its target labels as the
        # value, then return a new LU wrapping the same (frozen) base model.
        self.model.eval()
        if "bert" in self.config.model.name.lower():
            _, encoder_states = self.model(**batch, output_hidden_states=True)
        else:
            encoder_states = _last_encoder_state(self.model(**batch, output_hidden_states=True))

        memory_keys = []
        memory_labels = []
        for encoder_state, label in zip(encoder_states, batch["labels"]):
            avg_encoder_state = encoder_state.detach().mean(0)
            memory_keys.append(avg_encoder_state)
            memory_labels.append(label)

        memory = (torch.stack(memory_keys), torch.stack(memory_labels))
        return LU(self.model, self.config, self.model_constructor, memory), {}
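

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes a
    # hydra-style config exposing `model.name`, `lu.threshold`, and
    # `lu.onehot_logit`, and that the `_last_encoder_state`/`_logits` helpers
    # extract the final encoder hidden states and output logits from a
    # transformers model output. The checkpoint, threshold, and example text
    # below are illustrative placeholders.
    from types import SimpleNamespace
    from transformers import AutoTokenizer, T5ForConditionalGeneration

    config = SimpleNamespace(
        model=SimpleNamespace(name="t5-small"),
        lu=SimpleNamespace(threshold=2.5, onehot_logit=100.0),
    )
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    base = LU(model, config, model_constructor=lambda: model)

    batch = tokenizer(["The capital of France is <extra_id_0>."], return_tensors="pt")
    batch["labels"] = tokenizer(["Paris"], return_tensors="pt").input_ids

    # edit() stores (encoder key, label) pairs; forward() then overrides the
    # base logits whenever a query lands within the distance threshold.
    edited, _ = base.edit(batch)
    logits = edited(**batch)
    print(logits.argmax(-1))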