Charles Lin
All algs except KE working.
8335d0c
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import time
from editable_model import EditableModel
from utils import _last_encoder_state, _logits
class LU(EditableModel):
    """
    Representation lookup approach. Does not require training.

    An edit stores, for every edited example, the mean of its encoder states
    (the key) together with its label sequence (the value). At inference time
    the mean encoder state of each input is compared against the stored keys;
    when the nearest key is closer than ``config.lu.threshold`` the logits of
    that input are replaced by one-hot logits spelling out the stored labels.
    """

    def __init__(self, model, config, model_constructor, memory=None):
        """Wrap ``model``; ``memory`` is an optional ``(keys, labels)`` pair.

        Raises:
            NotImplementedError: if ``config.model.name`` does not contain
                "t5" (case-insensitive).
        """
        super().__init__(model, config, model_constructor)
        if "t5" not in self.config.model.name.lower():
            raise NotImplementedError
        self.memory = memory

    def lookup_replace(self, output, encoder_states):
        """Overwrite logits of inputs that match a stored edit.

        ``output`` is modified in place and also returned. It is indexed as
        ``output[example, position, token]``; ``encoder_states`` is iterated
        per example and each item is averaged over its first dimension.
        Requires ``self.memory`` to be set.
        """
        # Loop-invariant values: read them once instead of per example.
        memory_keys, memory_labels = self.memory
        threshold = self.config.lu.threshold
        onehot_logit = self.config.lu.onehot_logit
        # __init__ rejects non-T5 names, so this is False for instances built
        # through the constructor; the branch is kept for parity.
        shift_targets = "t5" not in self.config.model.name.lower()
        max_positions = output.shape[1]

        for i, encoder_state in enumerate(encoder_states):
            avg_encoder_state = encoder_state.detach().mean(0)
            dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
            closest_dist = dists.min()
            closest_idx = dists.argmin()
            if not closest_dist < threshold:
                continue

            closest_v = memory_labels[closest_idx]
            # Clear this example's logits in place (no temporary tensor).
            output[i] = 0
            # Write all one-hot logits in a single indexed assignment; labels
            # longer than the output sequence are truncated.
            # NOTE(review): negative label ids (e.g. an ignore index such as
            # -100) wrap around the vocabulary axis here, exactly as they did
            # with the previous per-token loop -- confirm this is intended.
            num_tokens = min(len(closest_v), max_positions)
            positions = torch.arange(num_tokens, device=output.device)
            token_ids = torch.as_tensor(
                closest_v[:num_tokens], dtype=torch.long, device=output.device
            )
            output[i, positions, token_ids] = onehot_logit

            if shift_targets:
                # T5 does not shift targets in the loss
                output[i] = output[i].roll(-1, -2)
        return output

    def generate(self, *inputs, **kwargs):
        """Run ``model.generate`` and return the argmax token ids of the logits.

        When a memory is present the logits are first passed through
        ``lookup_replace``. ``output_hidden_states``, ``output_scores`` and
        ``return_dict_in_generate`` are always set by this method, so callers
        must not pass them in ``kwargs``.
        """
        model_output = self.model.generate(
            *inputs,
            **kwargs,
            output_hidden_states=True,
            output_scores=True,
            return_dict_in_generate=True,
        )
        encoder_states = _last_encoder_state(model_output)
        output = _logits(model_output)
        if self.memory is not None:
            output = self.lookup_replace(output, encoder_states)
        return output.argmax(-1)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped model and return its (possibly replaced) logits.

        ``output_hidden_states`` is always set by this method, so callers must
        not pass it in ``kwargs``.
        """
        model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
        encoder_states = _last_encoder_state(model_output)
        output = _logits(model_output)
        if self.memory is not None:
            output = self.lookup_replace(output, encoder_states)
        return output

    def edit(self, batch, condition=None, detach_history=False):
        """Build a new ``LU`` whose memory holds the examples in ``batch``.

        ``batch`` is passed to the model as keyword arguments and must contain
        a "labels" entry. ``condition`` and ``detach_history`` are accepted
        but not used by this method.

        Returns:
            A tuple ``(edited_model, {})`` where ``edited_model`` is a new
            ``LU`` sharing this instance's model, config and constructor.
            Any memory held by this instance is not carried over.
        """
        eval_model = self.model.eval()

        # Keys are detached right away, so no autograd graph is needed.
        with torch.no_grad():
            if "bert" in self.config.model.name.lower():
                _, encoder_states = eval_model(**batch, output_hidden_states=True)
            else:
                encoder_states = _last_encoder_state(
                    eval_model(**batch, output_hidden_states=True)
                )

        # zip() pairs each example's encoder state with its label sequence.
        pairs = list(zip(encoder_states, batch["labels"]))
        memory_keys = torch.stack([state.detach().mean(0) for state, _ in pairs])
        memory_labels = torch.stack([label for _, label in pairs])
        memory = (memory_keys, memory_labels)

        return LU(eval_model, self.config, self.model_constructor, memory), {}