# NOTE: the three lines below ("Spaces: / Runtime error / Runtime error") were
# Hugging Face Spaces page-header residue captured along with this source file;
# they are preserved here as a comment so the module remains valid Python.
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from editable_model import EditableModel
from utils import _last_encoder_state, _logits
class LU(EditableModel):
    """Representation-lookup editing baseline. Does not require training.

    Edits are stored in ``self.memory`` as a ``(keys, labels)`` pair of
    tensors: each key is an edit example's encoder hidden state averaged
    over the sequence dimension, and the label is its target token ids.
    At inference time, any query whose averaged encoder state lies within
    ``config.lu.threshold`` (Euclidean distance) of a stored key has its
    logits overwritten with a one-hot encoding of the stored label.
    """

    def __init__(self, model, config, model_constructor, memory=None):
        """Wrap ``model`` for lookup-based editing.

        Args:
            model: the underlying seq2seq model (only T5 is supported).
            config: config object; this class reads ``config.model.name``,
                ``config.lu.threshold`` and ``config.lu.onehot_logit``.
            model_constructor: callable that re-creates the base model
                (consumed by ``EditableModel``).
            memory: optional ``(memory_keys, memory_labels)`` tensor pair
                produced by a previous ``edit()`` call; ``None`` means no
                edits are stored yet.
        """
        super().__init__(model, config, model_constructor)
        # Only T5-style encoder/decoder models are supported.
        if "t5" not in self.config.model.name.lower():
            raise NotImplementedError
        self.memory = memory

    def lookup_replace(self, output, encoder_states):
        """Overwrite logits of examples that match a stored edit.

        Args:
            output: logits tensor — assumed (batch, seq_len, vocab);
                modified in place and also returned.
            encoder_states: per-example encoder hidden states; each entry
                is averaged over its first (sequence) dimension to form
                the lookup query.

        Returns:
            ``output`` with matched rows replaced by one-hot logits.
        """
        # Memory is fixed for the whole batch; unpack it once, outside
        # the loop (the original re-unpacked it every iteration).
        memory_keys, memory_labels = self.memory
        for i, encoder_state in enumerate(encoder_states):
            avg_encoder_state = encoder_state.detach().mean(0)
            # Euclidean distance from the query key to every stored key.
            dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
            closest_dist = dists.min()
            closest_v = memory_labels[dists.argmin()]
            if closest_dist < self.config.lu.threshold:
                # Replace this example's logits with a one-hot encoding of
                # the retrieved label sequence.
                output[i] = torch.zeros(
                    (1, output.shape[1], output.shape[2]), device=output.device
                )
                for j, idx in enumerate(closest_v):
                    if j >= output.shape[1]:
                        # Retrieved label is longer than the output window.
                        break
                    output[i, j, idx] = self.config.lu.onehot_logit
                # NOTE(review): dead in practice — __init__ rejects non-T5
                # models, so this branch never runs. Kept for parity with
                # the original logic.
                if "t5" not in self.config.model.name.lower():
                    # T5 does not shift targets in the loss
                    output[i] = output[i].roll(-1, -2)
        return output

    def generate(self, *inputs, **kwargs):
        """Generate with the base model, then apply lookup edits.

        Returns:
            Token ids: argmax over the (possibly edited) logits.
        """
        model_output = self.model.generate(
            *inputs,
            **kwargs,
            output_hidden_states=True,
            output_scores=True,
            return_dict_in_generate=True,
        )
        encoder_states = _last_encoder_state(model_output)
        output = _logits(model_output)
        if self.memory is not None:
            output = self.lookup_replace(output, encoder_states)
        return output.argmax(-1)

    def forward(self, *inputs, **kwargs):
        """Forward pass returning logits, with lookup edits applied."""
        model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
        encoder_states = _last_encoder_state(model_output)
        output = _logits(model_output)
        if self.memory is not None:
            output = self.lookup_replace(output, encoder_states)
        return output

    def edit(self, batch, condition=None, detach_history=False):
        """Record the examples in ``batch`` as new lookup memory.

        No gradient steps are taken: each example's averaged encoder
        state becomes a memory key and its label the stored value. The
        returned wrapper's memory contains only this batch (it does not
        accumulate onto any previous memory).

        Args:
            batch: dict of model inputs; must include ``"labels"``.
            condition: unused; kept for interface compatibility.
            detach_history: unused; kept for interface compatibility.

        Returns:
            ``(edited_lu, info)``: a new ``LU`` sharing the same base
            model plus the new memory, and an empty info dict.
        """
        # Switch to eval mode before encoding. The original bound this to
        # an unused local (``edit_model = self.model.eval()``); only the
        # in-place side effect of .eval() matters.
        self.model.eval()
        if "bert" in self.config.model.name.lower():
            # NOTE(review): dead in practice — __init__ only accepts T5.
            _, encoder_states = self.model(**batch, output_hidden_states=True)
        else:
            encoder_states = _last_encoder_state(
                self.model(**batch, output_hidden_states=True)
            )
        memory_keys = []
        memory_labels = []
        for encoder_state, label in zip(encoder_states, batch["labels"]):
            # Average over the sequence dimension to get a fixed-size key.
            memory_keys.append(encoder_state.detach().mean(0))
            memory_labels.append(label)
        memory = (torch.stack(memory_keys), torch.stack(memory_labels))
        return LU(self.model.eval(), self.config, self.model_constructor, memory), {}