Charles Lin committed on
Commit
e56055d
1 Parent(s): 717a51e

Add algorithms from efk codebase

Files changed (15)
  1. algs/enn.py +114 -0
  2. algs/ft.py +121 -0
  3. algs/ke.py +312 -0
  4. algs/lu.py +90 -0
  5. algs/mend.py +297 -0
  6. algs/serac.py +452 -0
  7. app.py +1 -0
  8. editable_model.py +36 -0
  9. hooks.py +28 -0
  10. losses.py +181 -0
  11. metrics.py +135 -0
  12. models.py +196 -0
  13. nn.py +362 -0
  14. requirements.txt +6 -0
  15. utils.py +441 -0
algs/enn.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import higher
4
+
5
+ from editable_model import EditableModel
6
+ from utils import _logits
7
+
8
+
9
+ def fomaml_callback(all_grads):
10
+ return [g.detach() if g is not None else None for g in all_grads]
11
+
12
+
13
+ class ENN(EditableModel):
14
+ def __init__(self, model, config, model_constructor, edit_lrs=None, edit_loss_fn=None):
15
+ super().__init__(model, config, model_constructor)
16
+
17
+ if edit_lrs is None:
18
+ edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
19
+ self.edit_lrs = edit_lrs
20
+
21
+ if edit_loss_fn is not None:
22
+ self.edit_loss_fn = edit_loss_fn
23
+
24
+ self.grad_callback = fomaml_callback if config.enn.first_order else lambda x: x
25
+
26
+ def outer_parameters(self, grouped=False):
27
+ extra_params = [self.edit_lrs]
28
+ if self.config.no_grad_layers is None:
29
+ model_params = self.model.parameters() if type(self.model.parameters()) == list else list(self.model.parameters())
30
+ else:
31
+ model_params = []
32
+ for m in self.model.modules():
33
+ if isinstance(m, nn.ModuleList):
34
+ model_params.extend(list(m[self.config.no_grad_layers:].parameters()))
35
+
36
+ if grouped:
37
+ return [
38
+ dict(params=model_params, lr=self.config.lr),
39
+ dict(params=extra_params, lr=self.config.lr_lr)
40
+ ]
41
+ else:
42
+ return model_params + extra_params
43
+
44
+ def get_state_dict(self):
45
+ return self.state_dict()
46
+
47
+ def edit(self, batch, condition=None, detach_history=False):
48
+ opt = torch.optim.SGD([{"params": p, "lr": None}
49
+ for (n, p) in self.model.named_parameters() if n in self.config.model.inner_params])
50
+ with torch.enable_grad(), higher.innerloop_ctx(
51
+ self.model,
52
+ opt,
53
+ override={'lr': list(self.edit_lrs)},
54
+ copy_initial_weights=False,
55
+ track_higher_grads=self.training,
56
+ in_place=True
57
+ ) as (fmodel, diffopt):
58
+ fmodel.eval()
59
+ for edit_step in range(self.config.enn.n_edit_steps):
60
+ output = _logits(fmodel(**batch))
61
+ loss = self.edit_loss_fn(output, batch["labels"])["nll"]
62
+ diffopt.step(loss, grad_callback=self.grad_callback)
63
+
64
+ if not detach_history:
65
+ model_edited = fmodel
66
+ else:
67
+ model_edited = self.model_constructor()
68
+ model_edited.load_state_dict(fmodel.state_dict())
69
+ model_edited.train(self.training)
70
+
71
+ return ENN(model_edited, self.config, self.model_constructor, edit_lrs=self.edit_lrs, edit_loss_fn=self.edit_loss_fn), {}
72
+
73
+
74
+ def test():
75
+ import transformers
76
+ import types
77
+ import copy
78
+
79
+ model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
80
+
81
+ config = types.SimpleNamespace()
82
+ config.edit_lr = 0.1
83
+ config.model.inner_params = [
84
+ "transformer.h.9.mlp.c_fc.weight",
85
+ "transformer.h.9.mlp.c_proj.weight",
86
+ "transformer.h.10.mlp.c_fc.weight",
87
+ "transformer.h.10.mlp.c_proj.weight",
88
+ "transformer.h.11.mlp.c_fc.weight",
89
+ "transformer.h.11.mlp.c_proj.weight",
90
+ ]
91
+ config.enn = {
92
+ "n_edit_steps": 2,
93
+ "first_order": False
94
+ }
95
+
96
+ enn = ENN(model, config, lambda: copy.deepcopy(model)).cuda()
97
+
98
+ x = torch.arange(100).view(5, 20).cuda() + 1000
99
+
100
+ edited = enn.edit(x, masks=torch.ones_like(x), labels=x)
101
+
102
+ orig_param = [p for (n, p) in enn.model.named_parameters() if n == config.model.inner_params[-1]][0]
103
+ edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]
104
+
105
+ print((orig_param - edited_param).abs().max())
106
+ edited.eval()
107
+ print(enn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
108
+ edited.edit_loss_fn(edited(x).logits, x)["nll"].backward()
109
+ import pdb; pdb.set_trace()
110
+
111
+
112
+ if __name__ == '__main__':
113
+ with torch.autograd.set_detect_anomaly(True):
114
+ test()
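Note (illustrative, not part of the diff above): enn.py relies on a few differentiable SGD steps over selected inner parameters, taken inside higher.innerloop_ctx, so that an outer objective can backpropagate through the edit. A minimal sketch of that pattern in plain PyTorch, with toy tensors instead of a language model:

# Differentiable inner-loop edit, sketched without `higher` (toy example, not the commit's code).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(4, 4, requires_grad=True)      # an "inner" parameter the edit may touch
x, y = torch.randn(8, 4), torch.randn(8, 4)

W_t = W
for _ in range(2):                             # cf. config.enn.n_edit_steps
    inner_loss = F.mse_loss(x @ W_t, y)
    (g,) = torch.autograd.grad(inner_loss, W_t, create_graph=True)
    # ENN's fomaml_callback would detach g here for a first-order approximation.
    W_t = W_t - 0.1 * g                        # differentiable update, like diffopt.step

outer_loss = F.mse_loss(x @ W_t, y)            # loss measured on the edited parameters
outer_loss.backward()                          # gradients flow back into the original W
print(W.grad.norm())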
algs/ft.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import higher
4
+ from higher.patch import monkeypatch as make_functional
5
+ import time
6
+
7
+ from editable_model import EditableModel
8
+ from utils import _logits, _inner_params
9
+ from losses import kl_loc_loss
10
+
11
+
12
+ class FT(EditableModel):
13
+ """
14
+ Fine-tuning approach. Does not require training.
15
+ """
16
+
17
+ def __init__(self, model, config, model_constructor, edit_loss_fn=None):
18
+ super().__init__(model, config, model_constructor)
19
+
20
+ if edit_loss_fn is not None:
21
+ self.edit_loss_fn = edit_loss_fn
22
+
23
+ self.locality_loss_fn = kl_loc_loss
24
+ self.loc_ids = None
25
+ self.loc_masks = None
26
+ self.loc_sampler = None
27
+
28
+ def _edit_loss(self, model, p0, p_edited, edit_batch):
29
+ output = _logits(model(**edit_batch, params=p_edited))
30
+ loss_dict = self.edit_loss_fn(output, edit_batch["labels"])
31
+ l_edit, acc = loss_dict["nll"], loss_dict["acc"]
32
+ if self.config.ft.locality.enabled:
33
+ if self.config.ft.locality.oracle:
34
+ loc_batch = next(self.loc_sampler)["loc"]
35
+ else:
36
+ raise NotImplementedError
37
+
38
+ with torch.no_grad():
39
+ original_base_logits = _logits(model(**loc_batch, params=p0))
40
+ edited_base_logits = _logits(model(**loc_batch, params=p_edited))
41
+ kl_mask = loc_batch.get("decoder_attention_mask", loc_batch["attention_mask"])
42
+ l_loc = self.locality_loss_fn(original_base_logits, edited_base_logits, mask=kl_mask)
43
+ loss = l_loc + self.config.ft.locality.cedit * l_edit
44
+ else:
45
+ l_loc = torch.tensor(float('nan'))
46
+ loss = l_edit
47
+ return loss, l_edit, l_loc, acc
48
+
49
+ def accuracy(self, output, labels):
50
+ if output.shape[-1] != 1:
51
+ shifted_output = output.argmax(-1)[:, :-1]
52
+ shifted_labels = labels[:, 1:]
53
+ to_predict = (shifted_labels != -100).sum()
54
+ correct = (shifted_output == shifted_labels).sum()
55
+ acc = correct.float() / to_predict.float()
56
+ else:
57
+ acc = ((output > 0) == labels.bool()).sum().float()
58
+ return acc
59
+
60
+ def _edit_status(self, step, loss, l_edit, l_loc, acc, res_p):
61
+ return (
62
+ f"step: {step}".ljust(14) +
63
+ f"loss: {loss.item():.5f}".ljust(18) +
64
+ f"l_edit: {l_edit.item():.5f}".ljust(18) +
65
+ f"l_loc: {l_loc.item():.5f}".ljust(18) +
66
+ f"acc: {acc.item():.2f}".ljust(14) +
67
+ f"norm: {res_p.view(-1).norm().item():.5f}"
68
+ )
69
+
70
+ def edit(self, batch, condition=None, detach_history=False):
71
+ edit_model = self.model.eval()
72
+ p0 = list(edit_model.named_parameters())
73
+
74
+ if not isinstance(edit_model, higher.patch._MonkeyPatchBase):
75
+ edit_model = make_functional(self.model, track_higher_grads=False, in_place=True)
76
+
77
+ packed_residuals = {}
78
+ opt_params = []
79
+ for n, p in _inner_params(edit_model.named_parameters(), self.config.model.inner_params):
80
+ if self.config.ft.rank is not None:
81
+ u = nn.Parameter(torch.randn(p.shape[0], self.config.ft.rank, device=p.device) * self.config.ft.init_std)
82
+ v = nn.Parameter(torch.zeros(self.config.ft.rank, p.shape[1], device=p.device))
83
+ res = [u, v]
84
+ else:
85
+ res = [nn.Parameter(torch.zeros_like(p, device=p.device))]
86
+
87
+ packed_residuals[n] = res
88
+ opt_params.extend(res)
89
+
90
+ assert len(opt_params) == len(self.config.model.inner_params)
91
+ OptClass = getattr(torch.optim, self.config.ft.opt)
92
+ opt = OptClass(opt_params, lr=self.config.edit_lr)
93
+
94
+ start_time = time.time()
95
+ for edit_step in range(self.config.ft.max_edit_steps):
96
+ if self.config.ft.time_limit is not None and (time.time() - start_time > self.config.ft.time_limit):
97
+ break
98
+ residuals = {k: v[0] @ v[1] if len(v) == 2 else v[0] for k, v in packed_residuals.items()}
99
+ edited_params = [p if n not in residuals else p.detach() + residuals[n] for n, p in p0]
100
+ loss, l_edit, l_loc, acc = self._edit_loss(edit_model, [p for n, p in p0], edited_params, batch)
101
+
102
+ if self.config.ft.verbose:
103
+ residual = list(residuals.values())[-1]
104
+ print(self._edit_status(edit_step, loss, l_edit, l_loc, acc, residual), end="\r")
105
+
106
+ if acc == 1.0:
107
+ break
108
+
109
+ for p, g in zip(opt_params, torch.autograd.grad(loss, opt_params)):
110
+ p.grad = g
111
+ torch.nn.utils.clip_grad_norm_(opt_params, self.config.grad_clip)
112
+ opt.step()
113
+ opt.zero_grad()
114
+
115
+ if detach_history:
116
+ new_model = self.model_constructor()
117
+ new_model.load_state_dict(edit_model.state_dict())
118
+ edit_model = new_model
119
+ edit_model.train(self.training)
120
+
121
+ return FT(edit_model, self.config, self.model_constructor, self.edit_loss_fn), {}
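Note (illustrative, not part of the diff above): when config.ft.rank is set, ft.py optimizes a low-rank residual u @ v that is added to each frozen inner parameter rather than updating the parameter itself. A minimal sketch of that parameterization with toy shapes:

# Low-rank residual fine-tuning on a single frozen weight (toy example, not the commit's code).
import torch
import torch.nn as nn

torch.manual_seed(0)
W = torch.randn(16, 16)                         # frozen pretrained weight
x, target = torch.randn(4, 16), torch.randn(4, 16)

rank = 2
u = nn.Parameter(torch.randn(16, rank) * 0.01)  # small random init, cf. config.ft.init_std
v = nn.Parameter(torch.zeros(rank, 16))         # zero init: the edit starts as a no-op
opt = torch.optim.Adam([u, v], lr=1e-2)

for _ in range(100):
    W_edited = W + u @ v                        # residual applied to the untouched base weight
    loss = nn.functional.mse_loss(x @ W_edited, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
print(loss.item())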
algs/ke.py ADDED
@@ -0,0 +1,312 @@
1
+ # Adapted from https://github.com/nicola-decao/KnowledgeEditor/blob/main/src/models/one_shot_learner.py
2
+ """
3
+ @inproceedings{decao2020editing,
4
+ title={Editing Factual Knowledge in Language Models},
5
+ author={Nicola De Cao and Wilker Aziz and Ivan Titov},
6
+ booktitle={arXiv pre-print 2104.08164},
7
+ url={https://arxiv.org/abs/2104.08164},
8
+ year={2021},
9
+ }
10
+ """
11
+
12
+ import torch
13
+ import copy
14
+ import higher
15
+ from higher.patch import monkeypatch as make_functional
16
+ from allennlp.modules.feedforward import FeedForward
17
+ from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
18
+ import logging
19
+
20
+ from editable_model import EditableModel
21
+ from utils import _logits, _inner_params
22
+ from models import BertClassifier
23
+ from transformers import BartForConditionalGeneration, T5ForConditionalGeneration
24
+
25
+
26
+ LOG = logging.getLogger(__name__)
27
+
28
+
29
+ class KE(EditableModel):
30
+ def __init__(self, model, config, model_constructor, editor=None):
31
+ super().__init__(model, config, model_constructor)
32
+
33
+ if editor is None:
34
+ if isinstance(model, BertClassifier):
35
+ embedding = model.model.embeddings.word_embeddings.weight.data
36
+ elif isinstance(model, BartForConditionalGeneration):
37
+ embedding = model.model.shared.weight.data
38
+ elif isinstance(model, T5ForConditionalGeneration):
39
+ embedding = model.shared.weight.data
40
+ else:
41
+ embedding = model.transformer.wte.weight.data
42
+
43
+ editor = OneShotLearner(model, vocab_dim=model.config.vocab_size,
44
+ include_set=config.model.inner_params,
45
+ embedding_dim=embedding.shape[-1],
46
+ embedding_init=embedding.clone().to(torch.float32),
47
+ max_scale=1)
48
+ self.editor = editor
49
+
50
+ def outer_parameters(self, grouped=False):
51
+ if grouped:
52
+ return [
53
+ dict(params=self.editor.parameters(), lr=self.config.lr)
54
+ ]
55
+ else:
56
+ return list(self.editor.parameters())
57
+
58
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
59
+ state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars) # Get default state dict
60
+ model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys() # Remove model params
61
+ for k in model_keys:
62
+ del state_dict[f"model.{k}"]
63
+ state_dict["model_config"] = self.model.config # Include model config
64
+ return state_dict
65
+
66
+ def load_state_dict(self, state_dict, strict: bool = True):
67
+ config = state_dict["model_config"]
68
+ del state_dict["model_config"]
69
+ if config != self.model.config:
70
+ LOG.info("Loaded model config doesn't match current model config.")
71
+ LOG.info(f"Loaded: {config}")
72
+ LOG.info(f"Current: {self.model.config}")
73
+
74
+ res = super().load_state_dict(state_dict, False)
75
+ # We should only have missing keys for the model, and no unexpected keys
76
+ assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
77
+ assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
78
+ return res
79
+
80
+ def edit(self, batch, condition, detach_history=False):
81
+ outputs = _logits(self.model(**batch))
82
+ loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]
83
+
84
+ names = set([n for n, p in self.model.named_parameters()])
85
+ pset = set(self.config.model.inner_params)
86
+ for p in pset:
87
+ assert p in names, f"inner param {p} not in model"
88
+
89
+ grads = torch.autograd.grad(
90
+ loss,
91
+ [p for (n, p) in _inner_params(self.model.named_parameters(), self.config.model.inner_params)]
92
+ )
93
+
94
+ params_dict = self.editor(
95
+ condition["input_ids"] if condition is not None else batch["input_ids"],
96
+ condition["attention_mask"] if condition is not None else batch["attention_mask"],
97
+ {n: g.to(torch.float32) for (n, g) in zip(self.config.model.inner_params, grads)},
98
+ )
99
+
100
+ edited_model = self.model
101
+ if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
102
+ edited_model = make_functional(edited_model, in_place=True)
103
+
104
+ def new_param(n, p):
105
+ if n not in params_dict:
106
+ return p
107
+
108
+ if p.shape[0] == params_dict[n].shape[0]:
109
+ return p + params_dict[n]
110
+ else:
111
+ return p + params_dict[n].T
112
+
113
+ edited_model.update_params(
114
+ [new_param(n, p) for (n, p) in edited_model.named_parameters()]
115
+ )
116
+
117
+ if detach_history:
118
+ new_model = self.model_constructor()
119
+ new_model.load_state_dict(edited_model.state_dict())
120
+ edited_model = new_model
121
+
122
+ return KE(edited_model, self.config, self.model_constructor, editor=self.editor), {}
123
+
124
+
125
+ class ConditionedParameter(torch.nn.Module):
126
+ def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
127
+ super().__init__()
128
+ self.parameter_shape = parameter.shape
129
+
130
+ if len(self.parameter_shape) == 2:
131
+ self.conditioners = torch.nn.Sequential(
132
+ torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
133
+ torch.nn.Tanh(),
134
+ torch.nn.utils.weight_norm(
135
+ torch.nn.Linear(
136
+ hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
137
+ )
138
+ ),
139
+ )
140
+ elif len(self.parameter_shape) == 1:
141
+ self.conditioners = torch.nn.Sequential(
142
+ torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
143
+ torch.nn.Tanh(),
144
+ torch.nn.utils.weight_norm(
145
+ torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
146
+ ),
147
+ )
148
+ else:
149
+ raise RuntimeError()
150
+
151
+ self.max_scale = max_scale
152
+
153
+ def forward(self, inputs, grad):
154
+ if inputs.shape[0] > 1:
155
+ raise RuntimeError("Can only condition on batches of size 1")
156
+
157
+ if len(self.parameter_shape) == 2:
158
+ (
159
+ conditioner_cola,
160
+ conditioner_rowa,
161
+ conditioner_colb,
162
+ conditioner_rowb,
163
+ conditioner_norm,
164
+ ) = self.conditioners(inputs).split(
165
+ [
166
+ self.parameter_shape[1],
167
+ self.parameter_shape[0],
168
+ self.parameter_shape[1],
169
+ self.parameter_shape[0],
170
+ 1,
171
+ ],
172
+ dim=-1,
173
+ )
174
+
175
+ a = conditioner_rowa.softmax(-1).T @ conditioner_cola
176
+ b = conditioner_rowb.softmax(-1).T @ conditioner_colb
177
+
178
+ elif len(self.parameter_shape) == 1:
179
+ a, b, conditioner_norm = self.conditioners(inputs).split(
180
+ [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
181
+ )
182
+ else:
183
+ raise RuntimeError()
184
+
185
+ if a.squeeze().shape[0] != grad.shape[0]:
186
+ return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze().T + b.squeeze().T)
187
+ else:
188
+ return self.max_scale * conditioner_norm.sigmoid().squeeze() * (grad * a.squeeze() + b.squeeze())
189
+
190
+
191
+ class LSTMConditioner(torch.nn.Module):
192
+ def __init__(
193
+ self,
194
+ vocab_dim=30522,
195
+ embedding_dim=768,
196
+ hidden_dim=256,
197
+ output_dim=1024,
198
+ embedding_init=None,
199
+ ):
200
+ super().__init__()
201
+ self.embedding = torch.nn.Embedding(
202
+ num_embeddings=vocab_dim,
203
+ embedding_dim=embedding_dim,
204
+ padding_idx=0,
205
+ _weight=embedding_init,
206
+ )
207
+ self.lstm = PytorchSeq2VecWrapper(
208
+ torch.nn.LSTM(
209
+ input_size=embedding_dim,
210
+ hidden_size=hidden_dim,
211
+ num_layers=1,
212
+ bidirectional=True,
213
+ batch_first=True,
214
+ )
215
+ )
216
+ self.linear = FeedForward(
217
+ input_dim=hidden_dim * 2,
218
+ num_layers=1,
219
+ hidden_dims=[output_dim],
220
+ activations=[torch.nn.Tanh()],
221
+ )
222
+
223
+ def forward(self, inputs, masks):
224
+ return self.linear(self.lstm(self.embedding(inputs), masks))
225
+
226
+
227
+ class OneShotLearner(torch.nn.Module):
228
+ def __init__(
229
+ self,
230
+ model,
231
+ vocab_dim,
232
+ embedding_dim=768,
233
+ hidden_dim=512,
234
+ condition_dim=768,
235
+ include_set={},
236
+ max_scale=1e-3,
237
+ embedding_init=None,
238
+ ):
239
+ super().__init__()
240
+
241
+ self.param2conditioner_map = {
242
+ n: "{}_conditioner".format(n).replace(".", "_")
243
+ for n, p in model.named_parameters()
244
+ if n in include_set
245
+ }
246
+
247
+ self.conditioners = torch.nn.ModuleDict(
248
+ {
249
+ self.param2conditioner_map[n]: ConditionedParameter(
250
+ p,
251
+ condition_dim,
252
+ hidden_dim,
253
+ max_scale=max_scale,
254
+ )
255
+ for n, p in model.named_parameters()
256
+ if n in include_set
257
+ }
258
+ )
259
+
260
+ self.condition = LSTMConditioner(
261
+ vocab_dim,
262
+ embedding_dim,
263
+ hidden_dim,
264
+ condition_dim,
265
+ embedding_init=embedding_init,
266
+ )
267
+
268
+ def forward(self, inputs, masks, grads=None):
269
+ condition = self.condition(inputs, masks)
270
+ return {
271
+ p: self.conditioners[self.param2conditioner_map[p]](
272
+ condition,
273
+ grad=grads[p] if grads else None,
274
+ )
275
+ for p, c in self.param2conditioner_map.items()
276
+ }
277
+
278
+
279
+ if __name__ == '__main__':
280
+ import transformers
281
+ import types
282
+
283
+ model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
284
+
285
+ config = types.SimpleNamespace()
286
+ config.model.inner_params = [
287
+ "transformer.h.9.mlp.c_fc.weight",
288
+ "transformer.h.9.mlp.c_proj.weight",
289
+ "transformer.h.10.mlp.c_fc.weight",
290
+ "transformer.h.10.mlp.c_proj.weight",
291
+ "transformer.h.11.mlp.c_fc.weight",
292
+ "transformer.h.11.mlp.c_proj.weight",
293
+ ]
294
+
295
+ efk = KE(model, config, lambda: copy.deepcopy(model)).cuda()
296
+
297
+ x = torch.arange(20).view(1, 20).cuda() + 1000
298
+ orig_logits = efk(x).logits
299
+ edited = efk.edit(x, masks=torch.ones_like(x), labels=x)
300
+ post_logits = efk(x).logits
301
+
302
+ assert torch.allclose(orig_logits, post_logits)
303
+
304
+ orig_param = [p for (n, p) in efk.model.named_parameters() if n == config.model.inner_params[-1]][0]
305
+ edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]
306
+
307
+ print((orig_param - edited_param).abs().max())
308
+ edited.eval()
309
+ print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
310
+ edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x)
311
+ print(efk(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)
312
+ import pdb; pdb.set_trace()
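Note (illustrative, not part of the diff above): ConditionedParameter maps an encoding of the edit to a rank-1 "scale" and "shift" for the raw edit gradient, gated by a sigmoid. A simplified sketch for a single 2-D weight, omitting max_scale and the batching/squeeze details:

# KE-style rank-1 gradient modulation (toy example, not the commit's code).
import torch
import torch.nn as nn

rows, cols, cond_dim, hidden = 8, 6, 16, 32
hyper = nn.Sequential(nn.Linear(cond_dim, hidden), nn.Tanh(),
                      nn.Linear(hidden, 2 * (rows + cols) + 1))

condition = torch.randn(1, cond_dim)            # encoding of the edit statement
grad = torch.randn(rows, cols)                  # gradient of the edit loss w.r.t. the weight

col_a, row_a, col_b, row_b, gate = hyper(condition).split([cols, rows, cols, rows, 1], dim=-1)
a = row_a.softmax(-1).T @ col_a                 # rank-1 "scale", shape (rows, cols)
b = row_b.softmax(-1).T @ col_b                 # rank-1 "shift", shape (rows, cols)
update = gate.sigmoid() * (grad * a + b)        # predicted parameter shift
print(update.shape)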
algs/lu.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset
5
+ import time
6
+
7
+ from editable_model import EditableModel
8
+ from utils import _last_encoder_state, _logits
9
+
10
+ class LU(EditableModel):
11
+ """
12
+ Representation lookup approach. Does not require training.
13
+ """
14
+
15
+ def __init__(self, model, config, model_constructor, memory=None):
16
+ super().__init__(model, config, model_constructor)
17
+
18
+ self.memory = memory
19
+
20
+ def forward(self, *inputs, **kwargs):
21
+ if "bert" in self.config.model.name.lower():
22
+ output, encoder_states = self.model(*inputs, **kwargs, output_hidden_states=True)
23
+ else:
24
+ model_output = self.model(*inputs, **kwargs, output_hidden_states=True)
25
+ encoder_states = _last_encoder_state(model_output)
26
+ output = _logits(model_output)
27
+
28
+ if self.memory is not None:
29
+ for i, encoder_state in enumerate(encoder_states):
30
+ if "gpt2" in self.config.model.name.lower():
31
+ # NOTE: broken
32
+ memory_prefixes, memory_labels = self.memory
33
+ prefix_means = encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1)
34
+ dist_mat = (prefix_means.unsqueeze(1) - memory_prefixes.unsqueeze(0)).norm(2, dim=-1)
35
+
36
+ min_dists, min_idxs = dist_mat.min(-1)
37
+ memory_mask = (min_dists < self.config.lu.threshold)
38
+ onehot_logits = self.config.lu.onehot_logit * F.one_hot(memory_labels[min_idxs], output.shape[-1]).float()
39
+ output[i, memory_mask] = onehot_logits[memory_mask]
40
+ elif "bart" in self.config.model.name.lower() or "t5" in self.config.model.name.lower():
41
+ avg_encoder_state = encoder_state.detach().mean(0)
42
+ memory_keys, memory_labels = self.memory
43
+ dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
44
+ closest_dist = dists.min()
45
+ closest_idx = dists.argmin()
46
+ closest_v = memory_labels[closest_idx]
47
+
48
+ if closest_dist < self.config.lu.threshold:
49
+ output[i] = torch.zeros((1, kwargs['labels'].shape[1], output.shape[2]), device=output.device)
50
+ for j, idx in enumerate(closest_v):
51
+ if j >= output.shape[1]:
52
+ break
53
+ output[i, j, idx] = self.config.lu.onehot_logit
54
+ if "t5" not in self.config.model.name.lower():
55
+ # T5 does not shift targets in the loss
56
+ output[i] = output[i].roll(-1, -2)
57
+ else:
58
+ avg_encoder_state = encoder_state.detach().mean(0)
59
+ memory_keys, memory_labels = self.memory
60
+ dists = torch.norm(avg_encoder_state - memory_keys, dim=-1)
61
+ closest_dist = dists.min()
62
+ closest_idx = dists.argmin()
63
+ closest_v = memory_labels[closest_idx]
64
+
65
+ if closest_dist < self.config.lu.threshold:
66
+ output[i] = self.config.lu.onehot_logit * (2 * closest_v - 1) # Return onehot_logit or -onehot_logit
67
+
68
+ return output
69
+
70
+ def edit(self, batch, condition=None):
71
+ edit_model = self.model.eval()
72
+ if "bert" in self.config.model.name.lower():
73
+ _, encoder_states = self.model(**batch, output_hidden_states=True)
74
+ else:
75
+ encoder_states = _last_encoder_state(self.model(**batch, output_hidden_states=True))
76
+
77
+ memory_keys = []
78
+ memory_labels = []
79
+ for encoder_state, label in zip(encoder_states, batch["labels"]):
80
+ if "gpt2" in self.config.model.name.lower():
81
+ # NOTE: broken
82
+ avg_encoder_states = (encoder_state.cumsum(0).detach() / torch.arange(1, encoder_state.shape[0] + 1, device=encoder_state.device).view(-1, 1))[-10:, :]
83
+ memory = (avg_encoder_states, label[-10:])
84
+ else:
85
+ avg_encoder_state = encoder_state.detach().mean(0)
86
+ memory_keys.append(avg_encoder_state)
87
+ memory_labels.append(label)
88
+
89
+ memory = (torch.stack(memory_keys), torch.stack(memory_labels))
90
+ return LU(self.model.eval(), self.config, self.model_constructor, memory), {}
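Note (illustrative, not part of the diff above): lu.py serves edits by nearest-neighbor lookup on mean encoder states. The core decision rule, stripped of model plumbing:

# Representation lookup: match a query's mean encoder state against stored edit keys (toy example).
import torch

threshold = 1.0
memory_keys = torch.randn(5, 32)                # one stored mean encoder state per edit
memory_labels = torch.arange(5)                 # target associated with each stored edit

query = memory_keys[2] + 0.01 * torch.randn(32) # mean encoder state of an incoming input
dists = torch.norm(query - memory_keys, dim=-1)
closest_dist, closest_idx = dists.min(0)

if closest_dist < threshold:
    prediction = memory_labels[closest_idx]     # serve the memorized answer
else:
    prediction = None                           # defer to the base model
print(closest_dist.item(), prediction)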
algs/mend.py ADDED
@@ -0,0 +1,297 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import copy
5
+ import transformers
6
+ import higher
7
+ import logging
8
+ from higher.patch import monkeypatch as make_functional
9
+ from collections import defaultdict
10
+
11
+ from editable_model import EditableModel
12
+ from hooks import hook_model
13
+ import nn as local_nn
14
+ from utils import _logits, _inner_params
15
+
16
+ LOG = logging.getLogger(__name__)
17
+
18
+
19
+ def update_counter(x, m, s, k):
20
+ new_m = m + (x - m) / k
21
+ new_s = s + (x - m) * (x - new_m)
22
+
23
+ return new_m, new_s
24
+
25
+
26
+ class GradientTransform(nn.Module):
27
+ def __init__(self, x_dim: int, delta_dim: int, cfg, n_modes = None):
28
+ super().__init__()
29
+
30
+ self.x_dim = x_dim
31
+ self.delta_dim = delta_dim
32
+ self.cfg = cfg
33
+ if cfg.combine and (cfg.one_sided or cfg.x_only or cfg.delta_only):
34
+ raise ValueError("cfg.combine cannot be used with one-sided MEND variants")
35
+
36
+ self.norm_init = False
37
+ self.register_buffer("u_mean", torch.full((x_dim,), float("nan")))
38
+ self.register_buffer("v_mean", torch.full((delta_dim,), float("nan")))
39
+ self.register_buffer("u_std", torch.full((x_dim,), float("nan")))
40
+ self.register_buffer("v_std", torch.full((delta_dim,), float("nan")))
41
+ self.register_buffer("u_s", torch.full((x_dim,), float("nan")))
42
+ self.register_buffer("v_s", torch.full((delta_dim,), float("nan")))
43
+ self.register_buffer("k", torch.full((1,), float("nan")))
44
+
45
+ MlpClass = getattr(local_nn, cfg.mlp_class)
46
+ LOG.info(f"Building Gradient Transform with MLP class {MlpClass}")
47
+
48
+ def delta_net():
49
+ return MlpClass(delta_dim, delta_dim, delta_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)
50
+
51
+ def x_net():
52
+ return MlpClass(x_dim, x_dim, x_dim * 2, cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)
53
+
54
+ def combined_net():
55
+ return MlpClass(delta_dim + x_dim, delta_dim + x_dim, (delta_dim + x_dim) * 2,
56
+ cfg.n_hidden, init=cfg.init, act=cfg.act, rank=cfg.rank, n_modes=n_modes)
57
+
58
+ def ID():
59
+ return lambda x, mode=None: x
60
+
61
+ if cfg.combine:
62
+ self.mlp = combined_net()
63
+ elif cfg.one_sided:
64
+ if x_dim > delta_dim:
65
+ self.mlp1, self.mlp2 = ID(), delta_net()
66
+ else:
67
+ self.mlp1, self.mlp2 = x_net(), ID()
68
+ elif cfg.x_only:
69
+ self.mlp1, self.mlp2 = x_net(), ID()
70
+ elif cfg.delta_only:
71
+ self.mlp1, self.mlp2 = ID(), delta_net()
72
+ else:
73
+ self.mlp1, self.mlp2 = x_net(), delta_net()
74
+
75
+ def forward(self, u, v, param_idx=None):
76
+ u, v = u.to(torch.float32), v.to(torch.float32)
77
+
78
+ u_ = u.view(-1, u.shape[-1])
79
+ v_ = v.view(-1, v.shape[-1])
80
+
81
+ nz_mask = (u_ != 0).any(-1) * (v_ != 0).any(-1) # Skip batch elements with zero grad
82
+ u_ = u_[nz_mask]
83
+ v_ = v_[nz_mask]
84
+
85
+ if self.training:
86
+ for idx in range(u_.shape[0]):
87
+ if not self.norm_init:
88
+ self.u_mean = u_[idx].clone().detach()
89
+ self.v_mean = v_[idx].clone().detach()
90
+ self.u_s.zero_()
91
+ self.v_s.zero_()
92
+ self.k[:] = 1
93
+ self.norm_init = True
94
+ else:
95
+ self.k += 1
96
+ self.u_mean, self.u_s = update_counter(u_[idx], self.u_mean, self.u_s, self.k)
97
+ self.v_mean, self.v_s = update_counter(v_[idx], self.v_mean, self.v_s, self.k)
98
+
99
+ if self.k < 2:
100
+ raise RuntimeError(f"Can't perform normalization with only {self.k} samples so far")
101
+ self.u_std = (self.u_s / (self.k - 1)) ** 0.5
102
+ self.v_std = (self.v_s / (self.k - 1)) ** 0.5
103
+
104
+ if self.cfg.norm:
105
+ u_input = (u_ - self.u_mean) / (self.u_std + 1e-7)
106
+ v_input = (v_ - self.v_mean) / (self.v_std + 1e-7)
107
+ else:
108
+ u_input = u_
109
+ v_input = v_
110
+
111
+ if self.cfg.combine:
112
+ output = self.mlp(torch.cat((u_input, v_input), -1), mode=param_idx)
113
+ out1, out2 = output.split([u.shape[-1], v.shape[-1]], -1)
114
+ return out1, out2
115
+ else:
116
+ return self.mlp1(u_input, mode=param_idx), self.mlp2(v_input, mode=param_idx)
117
+
118
+
119
+ class MEND(EditableModel):
120
+ def get_shape(self, p):
121
+ # We need to (annoyingly) flip the shapes since OpenAI gpt2 uses convs instead of linear
122
+ return p.shape if isinstance(self.model, transformers.GPT2LMHeadModel) else (p.shape[1], p.shape[0])
123
+
124
+ def __init__(self, model, config, model_constructor, gtn=None, edit_lrs=None):
125
+ super().__init__(model, config, model_constructor)
126
+
127
+ if edit_lrs is None:
128
+ edit_lrs = nn.Parameter(torch.tensor([config.edit_lr] * len(self.config.model.inner_params)))
129
+ self.edit_lrs = edit_lrs
130
+
131
+ if not hasattr(self.model, "handles"):
132
+ hook_model(self.model, self.config.model.inner_params)
133
+ LOG.info(f"Hooked {len(self.model.handles)//2} modules")
134
+
135
+ if config.gtn.shared:
136
+ shape_dict = defaultdict(list)
137
+ for n, p in _inner_params(model.named_parameters(), self.config.model.inner_params):
138
+ shape_dict[self.get_shape(p)].append(n)
139
+ self.shape_dict = shape_dict
140
+
141
+ if gtn is None:
142
+ if not config.gtn.shared:
143
+ self.gtn = nn.ModuleDict({
144
+ n.replace(".", "#"): GradientTransform(*self.get_shape(p), config.gtn)
145
+ for (n, p) in _inner_params(model.named_parameters(), self.config.model.inner_params)
146
+ })
147
+ else:
148
+ self.gtn = nn.ModuleDict({
149
+ str(tuple(s)): GradientTransform(*s, config.gtn, len(shape_dict[s]))
150
+ for s in shape_dict.keys()
151
+ })
152
+ else:
153
+ self.gtn = gtn
154
+
155
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
156
+ state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars) # Get default state dict
157
+ model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys() # Remove model params
158
+ for k in model_keys:
159
+ del state_dict[f"model.{k}"]
160
+ state_dict["model_config"] = self.model.config # Include model config
161
+ return state_dict
162
+
163
+ def load_state_dict(self, state_dict, strict: bool = True):
164
+ config = state_dict["model_config"]
165
+ del state_dict["model_config"]
166
+ if config != self.model.config:
167
+ LOG.info("Loaded model config doesn't match current model config.")
168
+ LOG.info(f"Loaded: {config}")
169
+ LOG.info(f"Current: {self.model.config}")
170
+
171
+ res = super().load_state_dict(state_dict, False)
172
+ # We should only have missing keys for the model, and no unexpected keys
173
+ assert len([k for k in res.missing_keys if not k.startswith("model.")]) == 0, "Should only have missing keys for model."
174
+ assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
175
+ return res
176
+
177
+ def outer_parameters(self, grouped=False):
178
+ if grouped:
179
+ return [
180
+ dict(params=list(self.gtn.parameters()), lr=self.config.lr),
181
+ dict(params=[self.edit_lrs], lr=self.config.lr_lr)
182
+ ]
183
+ else:
184
+ return list(self.gtn.parameters()) + [self.edit_lrs]
185
+
186
+ def edit(self, batch, condition=None, detach_history=False):
187
+ outputs = _logits(self.model(**batch))
188
+ loss = self.edit_loss_fn(outputs, batch["labels"])["nll"]
189
+
190
+ names = set([n for n, p in self.model.named_parameters()])
191
+ pset = set(self.config.model.inner_params)
192
+ for p in pset:
193
+ assert p in names, f"inner param {p} not in model"
194
+
195
+ loss.backward()
196
+
197
+ if self.config.gtn.shared:
198
+ param_idx = lambda n, p: self.shape_dict[self.get_shape(p)].index(n) if self.config.gtn.shared else None # noqa: E731
199
+ transformed_factors = {
200
+ n: self.gtn[str(tuple(self.get_shape(p)))](p.__x__, p.__delta__, param_idx(n, p))
201
+ for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
202
+ }
203
+ else:
204
+ transformed_factors = {
205
+ n: self.gtn[n.replace(".", "#")](p.__x__, p.__delta__)
206
+ for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params)
207
+ }
208
+
209
+ # Should be bi,bj->ji for nn.Linear, but (annoyingly) GPT2 uses Conv1D instead...
210
+ if isinstance(self.model, transformers.GPT2LMHeadModel):
211
+ targ = "ij"
212
+ else:
213
+ targ = "ji"
214
+ mean_grads = {
215
+ n: torch.einsum(f"bi,bj->{targ}", x, delta)
216
+ for n, (x, delta) in transformed_factors.items()
217
+ }
218
+
219
+ info_dict = {}
220
+ idx = 0
221
+ for n, p in _inner_params(self.model.named_parameters(), self.config.model.inner_params):
222
+ info_dict[f"grad/true_mag{idx}"] = p.grad.norm(2).item()
223
+ info_dict[f"grad/pseudo_mag{idx}"] = mean_grads[n].norm(2).item()
224
+ info_dict[f"grad/true_std{idx}"] = p.grad.std().item()
225
+ info_dict[f"grad/pseudo_std{idx}"] = mean_grads[n].std().item()
226
+ info_dict[f"grad/diff{idx}"] = (p.grad - mean_grads[n]).norm(2).item()
227
+ info_dict[f"grad/cos{idx}"] = F.cosine_similarity(p.grad.reshape(-1), mean_grads[n].reshape(-1), dim=0).item()
228
+ idx += 1
229
+
230
+ self.model.zero_grad()
231
+
232
+ assert len(self.edit_lrs) == len(list(mean_grads.items()))
233
+ updates = {n: lr * g for lr, (n, g) in zip(self.edit_lrs, mean_grads.items())}
234
+
235
+ edited_model = self.model
236
+ if not isinstance(edited_model, higher.patch._MonkeyPatchBase):
237
+ edited_model = make_functional(edited_model, in_place=True)
238
+
239
+ new_params = []
240
+ for n, p in edited_model.named_parameters():
241
+ if n in pset:
242
+ if self.config.gtn.descent:
243
+ new_params.append(p - updates[n])
244
+ else:
245
+ new_params.append(p + updates[n])
246
+ else:
247
+ new_params.append(p)
248
+
249
+ edited_model.update_params(new_params)
250
+
251
+ if detach_history:
252
+ new_model = self.model_constructor()
253
+ new_model.load_state_dict(edited_model.state_dict())
254
+ edited_model = new_model
255
+
256
+ return MEND(edited_model, self.config, self.model_constructor, self.gtn, edit_lrs=self.edit_lrs), info_dict
257
+
258
+
259
+ if __name__ == '__main__':
260
+ import types
261
+
262
+ model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
263
+
264
+ config = types.SimpleNamespace()
265
+ config.model.inner_params = [
266
+ "transformer.h.9.mlp.c_fc.weight",
267
+ "transformer.h.9.mlp.c_proj.weight",
268
+ "transformer.h.10.mlp.c_fc.weight",
269
+ "transformer.h.10.mlp.c_proj.weight",
270
+ "transformer.h.11.mlp.c_fc.weight",
271
+ "transformer.h.11.mlp.c_proj.weight",
272
+ ]
273
+ config.edit_lr = 0.0001
274
+
275
+ config.gtn = types.SimpleNamespace()
276
+ config.gtn.n_hidden = 1
277
+ config.gtn = config.gtn.__dict__
278
+
279
+ gtn = MEND(model, config, lambda: copy.deepcopy(model)).cuda()
280
+ # torch.save(gtn.state_dict(), "test_state.pt")
281
+ import pdb; pdb.set_trace()
282
+ gtn.load_state_dict(torch.load("test_state.pt"))
283
+ x = torch.arange(20).view(1, 20).cuda() + 1000
284
+ orig_logits = gtn(x)
285
+ edited = gtn.edit(x, masks=torch.ones_like(x), labels=x)
286
+ post_logits = gtn(x)
287
+
288
+ assert torch.allclose(orig_logits, post_logits)
289
+
290
+ orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0]
291
+ edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]
292
+
293
+ LOG.info((orig_param - edited_param).abs().max())
294
+ edited.eval()
295
+ LOG.info(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited.edit_loss_fn(edited(x).logits, x)["nll"])
296
+ edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x)
297
+ LOG.info(gtn(x, labels=x).loss, edited(x, labels=x).loss, edited2(x, labels=x).loss)
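Note (illustrative, not part of the diff above): MEND's GradientTransform operates on the two rank-1 factors of a linear layer's gradient, the layer input and the gradient with respect to the layer output (the quantities hooks.py caches on each inner parameter as __x__ and __delta__), instead of on the full weight gradient. A toy check of that decomposition and of how a pseudo-gradient is reassembled:

# Rank-1 decomposition of an nn.Linear gradient, as exploited by MEND (toy example, not the commit's code).
import torch
import torch.nn as nn

batch, d_in, d_out = 4, 8, 6
x = torch.randn(batch, d_in)
W = torch.randn(d_out, d_in, requires_grad=True)
target = torch.randn(batch, d_out)

out = x @ W.T
loss = nn.functional.mse_loss(out, target)
delta = torch.autograd.grad(loss, out, retain_graph=True)[0]   # grad w.r.t. the layer output

recon = torch.einsum("bi,bj->ij", delta, x)                    # sum of per-example outer products
true_grad = torch.autograd.grad(loss, W)[0]
print(torch.allclose(recon, true_grad, atol=1e-5))             # True

# A MEND-style editor transforms the two factors separately, then recombines them.
x_net, delta_net = nn.Linear(d_in, d_in), nn.Linear(d_out, d_out)
pseudo_grad = torch.einsum("bi,bj->ij", delta_net(delta), x_net(x))
print(pseudo_grad.shape)                                       # torch.Size([6, 8])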
algs/serac.py ADDED
@@ -0,0 +1,452 @@
1
+ import torch
2
+ import copy
3
+ import transformers
4
+ import logging
5
+
6
+ from utils import scr, set_dropout, _logits, add_padding, add_sep
7
+ from editable_model import EditableModel
8
+ from models import BertClassifier
9
+
10
+ LOG = logging.getLogger(__name__)
11
+
12
+
13
+ def translate_tokens(tokens, from_tok, to_tok):
14
+ tokens = tokens.masked_fill(tokens == -100, from_tok.pad_token_id)
15
+ text = from_tok.batch_decode(tokens, skip_special_tokens=True)
16
+ return to_tok(text, return_tensors="pt")["input_ids"].to(tokens.device)
17
+
18
+
19
+ class SERAC(EditableModel):
20
+ def __init__(self, model, config, model_constructor, classifier=None, classifier_tok=None,
21
+ replacement=None, replacement_tok=None, cache_inputs=None, cache_labels=None,
22
+ cache_embeds=None, scale=None):
23
+ super().__init__(model, config, model_constructor)
24
+
25
+ if classifier is None:
26
+ if config.rep.cross_attend and not config.rep.cls_class.endswith("ForSequenceClassification"):
27
+ LOG.warn(f"Switching {config.rep.cls_class} to {config.rep.cls_class}ForSequenceClassification for cross-attend")
28
+ config.rep.cls_class += "ForSequenceClassification"
29
+ self.classifier = getattr(transformers, config.rep.cls_class).from_pretrained(config.rep.cls_name, cache_dir=scr())
30
+ if self.config.rep.checkpoint_grad:
31
+ LOG.info(f"Checking for checkpointing: {hasattr(self.classifier.config, 'gradient_checkpointing')}")
32
+ self.classifier.config.gradient_checkpointing = True
33
+ self.classifier_tok = transformers.AutoTokenizer.from_pretrained(config.rep.cls_name, cache_dir=scr())
34
+ if not self.config.rep.cross_attend and 'bert' in self.config.rep.cls_name:
35
+ self.classifier.pooler = None # we don't need the classification head
36
+ elif not self.config.rep.cross_attend and "mpnet" not in self.config.rep.cls_name:
37
+ if hasattr(self.classifier, "pooler"):
38
+ self.classifier.pooler = None # we don't need the classification head
39
+
40
+ set_dropout(self.classifier, config.dropout)
41
+ if self.config.rep.lora is not None:
42
+ self.classifier = LoraModel(self.classifier, self.config.rep.lora)
43
+ else:
44
+ assert isinstance(classifier, torch.nn.Module), f"Classifier is a {type(classifier)}!"
45
+ assert isinstance(classifier_tok, transformers.PreTrainedTokenizerBase), f"Classifier tok is {type(classifier_tok)}!"
46
+ self.classifier, self.classifier_tok = classifier, classifier_tok
47
+
48
+ if replacement is None:
49
+ # self.replacement_tok = getattr(transformers, config.model.tokenizer_class).from_pretrained(config.model.tokenizer_name,
50
+ # cache_dir=scr())
51
+ self.replacement_tok = transformers.AutoTokenizer.from_pretrained(config.model.small_name, cache_dir=scr())
52
+ # if self.replacement_tok.sep_token is None:
53
+ # self.replacement_tok.sep_token = self.replacement_tok.eos_token
54
+ if (False and self.config.rep.freeze_cntr):
55
+ self.replacement = None
56
+ else:
57
+ if config.model.class_name == "BertClassifier":
58
+ self.replacement = BertClassifier(config.model.small_name)
59
+ else:
60
+ self.replacement = getattr(transformers, config.model.class_name).from_pretrained(config.model.small_name, cache_dir=scr())
61
+ if self.replacement_tok.sep_token is None and "gpt" not in self.model.name_or_path.lower():
62
+ add_sep(self.replacement_tok, self.replacement)
63
+ if self.replacement_tok.pad_token is None:
64
+ add_padding(self.replacement_tok, self.replacement)
65
+ set_dropout(self.replacement, config.dropout)
66
+ else:
67
+ assert isinstance(replacement, torch.nn.Module), f"Rep is {type(replacement)}!"
68
+ assert isinstance(replacement_tok, transformers.PreTrainedTokenizerBase), f"Rep tok is {type(replacement_tok)}!"
69
+ self.replacement, self.replacement_tok = replacement, replacement_tok
70
+
71
+ if self.config.rep.cross_attend:
72
+ self.scale = None
73
+ else:
74
+ if scale is None:
75
+ self.register_buffer("scale", torch.tensor(1.0))
76
+ # self.scale = nn.Parameter(torch.tensor(1.0))
77
+ else:
78
+ self.scale = scale
79
+
80
+ if cache_inputs is None:
81
+ self.cache_inputs = []
82
+ self.cache_labels = []
83
+ if config.rep.cache_embeds and not config.rep.cross_attend:
84
+ self.cache_embeds = {}
85
+ else:
86
+ assert isinstance(cache_inputs, list), f"Cache inputs is {cache_inputs}"
87
+ assert isinstance(cache_labels, list), f"Cache labels is {cache_labels}"
88
+ self.cache_inputs = copy.deepcopy(cache_inputs)
89
+ self.cache_labels = copy.deepcopy(cache_labels)
90
+ if config.rep.cache_embeds and not config.rep.cross_attend:
91
+ assert isinstance(cache_embeds, dict), f"Cache embeds is {cache_embeds}"
92
+ self.cache_embeds = copy.deepcopy(cache_embeds)
93
+
94
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
95
+ state_dict = super().state_dict(prefix=prefix, keep_vars=keep_vars) # Get default state dict
96
+ model_keys = self.model.state_dict(prefix=prefix, keep_vars=keep_vars).keys() # Remove model params
97
+ for k in model_keys:
98
+ del state_dict[f"model.{k}"]
99
+ if self.config.rep.freeze_cntr:
100
+ cntr_keys = self.replacement.state_dict().keys()
101
+ for k in cntr_keys:
102
+ del state_dict[f"replacement.{k}"]
103
+ state_dict["model_config"] = self.model.config # Include model config
104
+ return state_dict
105
+
106
+ def load_state_dict(self, state_dict, strict: bool = True):
107
+ config = state_dict["model_config"]
108
+ del state_dict["model_config"]
109
+ if config != self.model.config:
110
+ LOG.info("Loaded model config doesn't match current model config.")
111
+ LOG.info(f"Loaded: {config}")
112
+ LOG.info(f"Current: {self.model.config}")
113
+
114
+ if (False and self.config.rep.freeze_cntr):
115
+ rep_keys = list(state_dict.keys())
116
+ for k in rep_keys:
117
+ if k.startswith("replacement"):
118
+ del state_dict[k]
119
+ res = super().load_state_dict(state_dict, False)
120
+ else:
121
+ try:
122
+ res = super().load_state_dict(state_dict, False)
123
+ except RuntimeError:
124
+ LOG.info("Load failed; trying again without loading counterfactual model weights.")
125
+ rep_keys = list(state_dict.keys())
126
+ for k in rep_keys:
127
+ if k.startswith("replacement"):
128
+ del state_dict[k]
129
+ res = super().load_state_dict(state_dict, False)
130
+
131
+ # We should only have missing keys for the model, and no unexpected keys
132
+ def ok_to_miss(k):
133
+ return k.startswith("model.") or ((False and self.config.rep.freeze_cntr) and k.startswith("replacement."))
134
+ missing_keys = [k for k in res.missing_keys if not ok_to_miss(k)]
135
+ assert len(missing_keys) == 0, f"Should only have missing keys for model: {missing_keys}."
136
+ assert len(res.unexpected_keys) == 0, "Shouldn't have any unexpected keys"
137
+ return res
138
+
139
+ def outer_parameters(self, grouped=False):
140
+ if self.config.rep.freeze is not None:
141
+ modlist = None
142
+ for m in self.classifier.modules():
143
+ if isinstance(m, torch.nn.ModuleList):
144
+ modlist = m
145
+ break
146
+ model_params = list(modlist[-self.config.rep.freeze:].parameters())
147
+ else:
148
+ model_params = list(self.classifier.parameters())
149
+
150
+ if self.config.rep.lora is not None or self.config.rep.freeze is not None:
151
+ cls = self.classifier.base_model if self.config.rep.lora else self.classifier
152
+ if hasattr(cls, "classifier"):
153
+ model_params.extend(cls.classifier.parameters())
154
+ if hasattr(cls, "pre_classifier"):
155
+ model_params.extend(cls.pre_classifier.parameters())
156
+
157
+ if not (False and self.config.rep.freeze_cntr):
158
+ model_params.extend(list(self.replacement.parameters()))
159
+
160
+ extra_params = []
161
+ if grouped:
162
+ return [
163
+ dict(params=model_params, lr=self.config.lr),
164
+ dict(params=extra_params, lr=self.config.lr_lr)
165
+ ]
166
+ else:
167
+ return model_params + extra_params
168
+
169
+ def edit(self, batch, condition=None, detach_history=False):
170
+ def detokenize(toks, tok):
171
+ tokens = toks.masked_fill(toks == -100, tok.pad_token_id)
172
+ return tok.batch_decode(tokens, skip_special_tokens=True)
173
+
174
+ inputs = detokenize(batch["input_ids"], self.replacement_tok)
175
+ if "bert" in self.config.model.name:
176
+ labels = ["" for _ in batch["labels"]]
177
+ else:
178
+ labels = detokenize(batch["labels"], self.replacement_tok)
179
+
180
+ cache_inputs = self.cache_inputs + inputs
181
+ cache_labels = self.cache_labels + labels
182
+
183
+ if self.config.rep.cache_embeds and not self.config.rep.cross_attend:
184
+ cls_inputs = self.build_cls_cache_inputs(inputs, labels)
185
+ with torch.no_grad():
186
+ embeds = self.compute_cls_embeddings(cls_inputs)
187
+
188
+ cache_embeds = {inp: emb for inp, emb in zip(cls_inputs, embeds)}
189
+ cache_embeds.update(self.cache_embeds)
190
+ else:
191
+ cache_embeds = None
192
+
193
+ new_model = SERAC(self.model, self.config, self.model_constructor, self.classifier, self.classifier_tok,
194
+ self.replacement, self.replacement_tok, cache_inputs, cache_labels, cache_embeds, self.scale)
195
+ new_model.train(self.training)
196
+ return new_model, {}
197
+
198
+ def stats(self):
199
+ return self.last_stats
200
+
201
+ def compute_cls_embeddings(self, text):
202
+ inputs = self.classifier_tok(text, return_tensors="pt", padding=True).to(self.config.device)
203
+ if 'bert' in self.config.rep.cls_name:
204
+ embeds = self.classifier(**inputs).last_hidden_state[:, 0].unsqueeze(1)
205
+ else:
206
+ embeds = self.classifier(**inputs).pooler_output.unsqueeze(1)
207
+ embeds = embeds.view(embeds.shape[0], self.config.rep.dist_heads, -1)
208
+ if self.config.rep.bound_embeds:
209
+ embeds = embeds.tanh()
210
+ return embeds
211
+
212
+ def embedding_logsim_matrix(self, cls_ctxs, test_input_text):
213
+ if self.config.rep.cache_embeds and not self.config.rep.cross_attend and not self.training:
214
+ ctx_embeds = torch.cat([self.cache_embeds[ctx] for ctx in cls_ctxs])
215
+ else:
216
+ ctx_embeds = self.compute_cls_embeddings(cls_ctxs)
217
+ main_embeds = self.compute_cls_embeddings(test_input_text)
218
+
219
+ if self.config.rep.cos:
220
+ cos = (ctx_embeds[None] * main_embeds[:, None]).sum(-1) / (ctx_embeds[None].norm(2, -1) * main_embeds[:, None].norm(2, -1))
221
+ dists = 1 - cos
222
+ else:
223
+ dists = (ctx_embeds[None] - main_embeds[:, None]).norm(2, -1)
224
+ if self.config.rep.square:
225
+ dists = dists ** 2
226
+
227
+ dists = dists.min(-1).values # get rid of the dists head dimension
228
+
229
+ assert dists.min() >= 0, "Shouldn't have negative distances!"
230
+ cls_logsims = -dists * self.scale
231
+
232
+ return cls_logsims
233
+
234
+ def crossattend_logsim_matrix(self, cls_ctxs, test_input_texts):
235
+ batch = [ctx + self.classifier_tok.sep_token + test for test in test_input_texts for ctx in cls_ctxs]
236
+ batch_toks = self.classifier_tok(batch, return_tensors="pt", padding=True).to(self.config.device)
237
+ batch_logsims = self.classifier(**batch_toks).logits.log_softmax(-1)[:, 0]
238
+ logsim_matrix = batch_logsims.view(len(test_input_texts), len(cls_ctxs))
239
+
240
+ return logsim_matrix
241
+
242
+ def build_rep_cache_contexts(self):
243
+ sep = " "
244
+ if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
245
+ # The labels are included in the inputs for autoregressive models. Cut off the label for the classifier
246
+ ctxs = [cin + sep for cin in self.cache_inputs]
247
+ else:
248
+ ctxs = [cin + sep + clab + sep for cin, clab in zip(self.cache_inputs, self.cache_labels)]
249
+ return ctxs
250
+
251
+ def build_cls_cache_inputs(self, cache_inputs=None, cache_labels=None):
252
+ sep = self.classifier_tok.sep_token
253
+ if cache_inputs is None:
254
+ cache_inputs = self.cache_inputs
255
+ if cache_labels is None:
256
+ cache_labels = self.cache_labels
257
+
258
+ if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
259
+ # The labels are included in the inputs for autoregressive models. Cut off the label for the classifier
260
+ inputs = [cin.rsplit(" ", 1)[0] + sep for cin in cache_inputs]
261
+ else:
262
+ inputs = [cin + sep + clab + sep for cin, clab in zip(cache_inputs, cache_labels)]
263
+ return inputs
264
+
265
+ def build_rep_input_tokens(self, kwargs, idxs, generation=False):
266
+ assert len(idxs) == len(kwargs["input_ids"]), "Need one cache idx for each test input"
267
+ cache_contexts = self.build_rep_cache_contexts()
268
+ selected_contexts = [cache_contexts[idx.item()] for idx in idxs]
269
+ test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
270
+ rep_texts = [ctx + inp for ctx, inp in zip(selected_contexts, test_inputs)]
271
+ rep_input_tokens = self.replacement_tok(rep_texts, return_tensors="pt", padding=True).to(self.config.device)
272
+
273
+ rep_kwargs = {
274
+ "input_ids": rep_input_tokens["input_ids"],
275
+ "attention_mask": rep_input_tokens["attention_mask"],
276
+ }
277
+
278
+ if not generation:
279
+ rep_kwargs["labels"] = kwargs["labels"]
280
+
281
+ # if self.config.task in ["fc", "fnli"]:
282
+ # del rep_kwargs["labels"]
283
+
284
+ if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
285
+ # Add 'ignore' labels for the prepended cache inputs
286
+ pre = torch.full((kwargs["labels"].shape[0], rep_kwargs["input_ids"].shape[-1] - kwargs["labels"].shape[-1]), -100,
287
+ device=kwargs["labels"].device)
288
+ rep_kwargs["labels"] = torch.cat((pre, kwargs["labels"]), dim=-1)
289
+
290
+ return rep_kwargs
291
+
292
+ def run_classifier(self, *inputs, **kwargs):
293
+ cache_inputs = self.build_cls_cache_inputs()
294
+ test_inputs = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
295
+
296
+ if self.config.rep.cross_attend:
297
+ log_sim_matrix = self.crossattend_logsim_matrix(cache_inputs, test_inputs)
298
+ else:
299
+ log_sim_matrix = self.embedding_logsim_matrix(cache_inputs, test_inputs)
300
+
301
+ sims = log_sim_matrix.exp()
302
+ assert sims.max() <= 1, "Similarities shouldn't exceed 1!"
303
+
304
+ cls_sims, cls_idxs = sims.max(-1)
305
+ return cls_sims, cls_idxs, log_sim_matrix
306
+
307
+ def generate(self, *args, **kwargs):
308
+ # input_text = self.replacement_tok.batch_decode(kwargs["input_ids"], skip_special_tokens=True)
309
+ base_generate_fn = (
310
+ self.model.forward if type(self.model) == BertClassifier
311
+ else lambda *args, **kwargs: self.model.generate(*args, **kwargs, max_new_tokens=20)
312
+ )
313
+ cntr_generate_fn = (
314
+ self.replacement.forward if type(self.replacement) == BertClassifier
315
+ else lambda *args, **kwargs: self.replacement.generate(*args, **kwargs, max_new_tokens=20)
316
+ )
317
+
318
+ # assert len(args) == 0, "Should only pass named arguments to generate()"
319
+ if len(self.cache_inputs) > 0:
320
+ override = kwargs.get("override")
321
+ if override:
322
+ del kwargs["override"]
323
+
324
+ cls_sims, cls_idxs, _ = self.run_classifier(*args, **kwargs)
325
+ # assert cls_sims.numel() == 1
326
+ # print(f"Cache score: {cls_sims.item()} " + ("[MISS]" if cls_sims.item() < 0.5 else "[HIT]"))
327
+ use_cntr = (override == "cntr") if override is not None else (cls_sims.item() > 0.5)
328
+ if use_cntr:
329
+ rep_input = self.build_rep_input_tokens(kwargs, cls_idxs, generation=True)
330
+ kwargs["input_ids"] = rep_input["input_ids"]
331
+ kwargs["attention_mask"] = rep_input["attention_mask"]
332
+ # rep_input_text = self.replacement_tok.decode(rep_input["input_ids"][0])
333
+ # print(f"Returning counterfactual model output for '{rep_input_text}'")
334
+ if self.config.rep.freeze_cntr:
335
+ return base_generate_fn(*args, **kwargs)
336
+ else:
337
+ return cntr_generate_fn(*args, **kwargs)
338
+
339
+ # print(f"Returning base model output for '{input_text}'")
340
+ return base_generate_fn(*args, **kwargs)
341
+
342
+ def forward(self, *inputs, return_logits_only=True, eps=torch.finfo(torch.float32).eps, pos_pairs=None, **kwargs):
343
+ grad_enabled = torch.is_grad_enabled()
344
+ torch.set_grad_enabled(self.training)
345
+
346
+ # need to do soft mixing of logits if we're doing supervised training or we've specifically requested it
347
+ soft = (not self.config.rep.supervised) or self.config.rep.soft_weighting
348
+ with torch.no_grad():
349
+ if len(self.cache_inputs) == 0:
350
+ super_out = super().forward(*inputs, **kwargs).float()
351
+ torch.set_grad_enabled(grad_enabled)
352
+ return super_out
353
+ else:
354
+ base_logits = super().forward(*inputs, **kwargs).float()
355
+ if soft:
356
+ if base_logits.dim() == 3:
357
+ base_probs = base_logits.softmax(-1)
358
+ else:
359
+ base_probs = base_logits.sigmoid()
360
+ del base_logits
361
+
362
+ cls_sims, cls_idxs, cls_logits = self.run_classifier(*inputs, **kwargs)
363
+ rep_cls_inputs = self.build_rep_input_tokens(kwargs, cls_idxs)
364
+ if self.config.rep.freeze_cntr:
365
+ rep_cls_logits = _logits(super().forward(**rep_cls_inputs))
366
+ else:
367
+ rep_cls_logits = _logits(self.replacement(**rep_cls_inputs))
368
+
369
+ if pos_pairs is not None:
370
+ assert (pos_pairs[:, 0] == torch.arange(pos_pairs.shape[0], device=pos_pairs.device)).all()
371
+ gold_idxs = pos_pairs[:, 1]
372
+ # print("IDX acc:", (cls_idxs == gold_idxs).shape, (cls_idxs == gold_idxs).float().mean())
373
+ rep_gold_inputs = self.build_rep_input_tokens(kwargs, gold_idxs)
374
+ if (False and self.config.rep.freeze_cntr):
375
+ rep_gold_logits = _logits(super().forward(**rep_gold_inputs))
376
+ else:
377
+ rep_gold_logits = _logits(self.replacement(**rep_gold_inputs))
378
+ else:
379
+ rep_gold_logits = rep_cls_logits
380
+
381
+ cls_sims = cls_sims.view(-1, 1) # For (binary) classification, predictions are (B x 1)
382
+ if rep_cls_logits.dim() == 3:
383
+ cls_sims.unsqueeze_(-1) # For generation/seq2seq, predictions are (B x S x V)
384
+
385
+ stats = {
386
+ 'sims/mean': cls_sims.mean().item(),
387
+ 'sims/pos': (cls_sims >= 0.5).float().mean().item(),
388
+ 'sims/neg': (cls_sims < 0.5).float().mean().item(),
389
+ 'params/scale': self.scale.item() if self.scale is not None else 0.0,
390
+ }
391
+
392
+ if hasattr(self.model, "name_or_path") and "gpt" in self.model.name_or_path.lower():
393
+ rep_cls_logits = rep_cls_logits[:, -kwargs["labels"].shape[-1]:, :]
394
+
395
+ if soft:
396
+ rep_weight = cls_sims
397
+ if base_probs.dim() == 3:
398
+ mixture_logits = ((1 - rep_weight) * base_probs + rep_weight * rep_cls_logits.softmax(-1) + eps).log()
399
+ else:
400
+ mixture_logits = ((1 - rep_weight) * base_probs + rep_weight * rep_cls_logits.sigmoid() + eps).log()
401
+ else:
402
+ rep_idxs = torch.where(cls_sims > 0.5)[0]
403
+ mixture_logits = base_logits
404
+ if rep_idxs.numel() > 0:
405
+ mixture_logits[rep_idxs] = rep_cls_logits[rep_idxs]
406
+
407
+ torch.set_grad_enabled(grad_enabled)
408
+ if return_logits_only:
409
+ return mixture_logits
410
+ else:
411
+ return mixture_logits, cls_logits, rep_gold_logits, stats
412
+
413
+
414
+ if __name__ == '__main__':
415
+ import types
416
+
417
+ model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
418
+
419
+ config = types.SimpleNamespace()
420
+ config.model = types.SimpleNamespace()
+ config.model.inner_params = [
421
+ "transformer.h.9.mlp.c_fc.weight",
422
+ "transformer.h.9.mlp.c_proj.weight",
423
+ "transformer.h.10.mlp.c_fc.weight",
424
+ "transformer.h.10.mlp.c_proj.weight",
425
+ "transformer.h.11.mlp.c_fc.weight",
426
+ "transformer.h.11.mlp.c_proj.weight",
427
+ ]
428
+ config.edit_lr = 0.0001
429
+
430
+ config.gtn = types.SimpleNamespace()
431
+ config.gtn.n_hidden = 1
432
+ config.gtn = config.gtn.__dict__
433
+
434
+ gtn = SERAC(model, config, lambda: copy.deepcopy(model)).cuda()
435
+ # torch.save(gtn.state_dict(), "test_state.pt")
436
+ import pdb; pdb.set_trace()
437
+ gtn.load_state_dict(torch.load("test_state.pt"))
438
+ x = torch.arange(20).view(1, 20).cuda() + 1000
439
+ orig_logits = gtn(x)
440
+ edited = gtn.edit(x, masks=torch.ones_like(x), labels=x)
441
+ post_logits = gtn(x)
442
+
443
+ assert torch.allclose(orig_logits, post_logits)
444
+
445
+ orig_param = [p for (n, p) in gtn.model.named_parameters() if n == config.model.inner_params[-1]][0]
446
+ edited_param = [p for (n, p) in edited.model.named_parameters() if n == config.model.inner_params[-1]][0]
447
+
448
+ LOG.info((orig_param - edited_param).abs().max())
449
+ edited.eval()
450
+ LOG.info(f"{gtn(x, labels=x).loss} {edited(x, labels=x).loss} {edited.edit_loss_fn(edited(x).logits, x)['nll']}")
451
+ edited2 = edited.edit(x, masks=torch.ones_like(x), labels=x)
452
+ LOG.info(f"{gtn(x, labels=x).loss} {edited(x, labels=x).loss} {edited2(x, labels=x).loss}")
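For reference, the soft branch of `forward` above mixes the base and counterfactual models at the probability level and only then takes a log. A minimal standalone sketch of that interpolation (toy tensors; shapes and weights are hypothetical, no SERAC classes assumed):

    import torch

    base_logits = torch.randn(2, 3, 5)    # (batch, seq, vocab) from the frozen base model
    cntr_logits = torch.randn(2, 3, 5)    # same shape, from the counterfactual model
    w = torch.rand(2, 1, 1)               # classifier similarity in [0, 1], broadcast over seq/vocab
    eps = torch.finfo(torch.float32).eps

    # Mirrors the soft branch: interpolate probabilities, then return log-probabilities.
    mixture_logits = ((1 - w) * base_logits.softmax(-1) + w * cntr_logits.softmax(-1) + eps).log()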
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import time
 
4
 
5
  EDIT_ALGS = [
6
  "MEND: Model editor networks using gradient decomposition",
 
1
  import streamlit as st
2
  import pandas as pd
3
  import time
4
+ import algs
5
 
6
  EDIT_ALGS = [
7
  "MEND: Model editor networks using gradient decomposition",
editable_model.py ADDED
@@ -0,0 +1,36 @@
1
+ import torch.nn as nn
2
+
3
+ from losses import masked_log_probs
4
+ from utils import _logits, shift_targets
5
+
6
+
7
+ class EditableModel(nn.Module):
8
+ def __init__(self, model, config, model_constructor):
9
+ super().__init__()
10
+
11
+ self.model = model
12
+ self.config = config
13
+ self.model_constructor = model_constructor
14
+
15
+ def _edit_loss_fn(pred, targ, **kwargs):
16
+ return masked_log_probs(pred, targ, shift=shift_targets(self.config), **kwargs)
17
+ self.edit_loss_fn = _edit_loss_fn
18
+ self.loc_loss_fn = _edit_loss_fn
19
+
20
+ def edit(self, batch, condition=None, detach_history=False):
21
+ raise NotImplementedError
22
+
23
+ def forward(self, *inputs, **kwargs):
24
+ return _logits(self.model(*inputs, **kwargs))
25
+
26
+ def outer_parameters(self, grouped=False):
27
+ if grouped:
28
+ return [dict(params=self.parameters(), lr=self.config.lr)]
29
+ else:
30
+ return list(self.parameters())
31
+
32
+ def generate(self, *args, **kwargs):
33
+ return self.model.generate(*args, **kwargs)
34
+
35
+ def base_loss(self, input_ids, attention_masks, label_ids):
36
+ pass
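To illustrate the contract above: concrete editors subclass `EditableModel` and override `edit` to return a freshly wrapped model. A minimal sketch (the `NoOpEditor` name and its do-nothing edit are hypothetical; exact return conventions vary slightly across the algorithms in this repo):

    from editable_model import EditableModel

    class NoOpEditor(EditableModel):
        """Hypothetical editor whose 'edit' just returns an untouched copy of the model."""

        def edit(self, batch, condition=None, detach_history=False):
            edited = self.model_constructor()                # fresh model with the same architecture
            edited.load_state_dict(self.model.state_dict())  # a real editor would apply an update here
            return NoOpEditor(edited, self.config, self.model_constructor)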
hooks.py ADDED
@@ -0,0 +1,28 @@
1
+ from utils import parent_module
2
+
3
+
4
+ def linear_backward_hook(mod, grad_in, grad_out):
5
+ if not hasattr(mod, "weight"):
6
+ print(f"{mod} has no weight!")
7
+ return
8
+
9
+ if hasattr(mod.weight, "__x__"):
10
+ assert len(grad_out) == 1
11
+ # mod.weight.__bgrad__ = grad_out[0].unsqueeze(-1) * mod.__x__[0].unsqueeze(-2)
12
+ mod.weight.__delta__ = grad_out[0].detach()
13
+ else:
14
+ print(f"{mod} has no __x__")
15
+
16
+
17
+ def linear_forward_hook(mod, activations, output):
18
+ assert len(activations) == 1
19
+ mod.weight.__x__ = activations[0].detach()
20
+
21
+
22
+ def hook_model(model, pnames):
23
+ handles = []
24
+ for m in [parent_module(model, pname) for pname in pnames]:
25
+ handles.append(m.register_full_backward_hook(linear_backward_hook))
26
+ handles.append(m.register_forward_hook(linear_forward_hook))
27
+
28
+ model.handles = handles
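A small usage sketch for the hooks above, assuming the repo root is importable (the toy two-layer model is illustrative):

    import torch
    import torch.nn as nn
    from hooks import hook_model

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    hook_model(model, ["0.weight", "2.weight"])   # parent_module resolves "0.weight" -> model[0]

    out = model(torch.randn(3, 4))
    out.sum().backward()

    # Each hooked weight now carries its layer input and its output gradient.
    print(model[0].weight.__x__.shape)       # torch.Size([3, 4])
    print(model[0].weight.__delta__.shape)   # torch.Size([3, 8])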
losses.py ADDED
@@ -0,0 +1,181 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from metrics import es_sentiment
4
+ from utils import gather_log_probs, mask_hf_labels, masked_mean
5
+
6
+
7
+ def balanced_bce(log_probs, labels, eps=torch.finfo(torch.float32).eps):
8
+ assert labels.max() <= 1
9
+ assert labels.min() >= 0
10
+
11
+ pos_losses = -log_probs[labels == 1]
12
+ neg_probs = 1 - log_probs.exp()
13
+ neg_probs[neg_probs == 0] += eps # for numerical stability
14
+ neg_losses = -neg_probs.log()[labels == 0]
15
+ pos_loss = pos_losses.mean() if pos_losses.numel() > 0 else 0
16
+ neg_loss = neg_losses.mean() if neg_losses.numel() > 0 else 0
17
+
18
+ return pos_loss + neg_loss
19
+
20
+
21
+ def kl_loc_loss(pre, post, mask=None):
22
+ pre = pre.to(torch.float32)
23
+ post = post.to(torch.float32)
24
+
25
+ sequence = pre.dim() == 3
26
+ pre_ = pre.view(-1, pre.shape[-1])
27
+ post_ = post.view(pre_.shape)
28
+ assert pre_.shape[0] == post_.shape[0]
29
+
30
+ if not sequence:
31
+ if pre_.shape[-1] == 1: # No masking needed for binary classification
32
+ return (pre.sigmoid() * (F.logsigmoid(pre) - F.logsigmoid(post))).mean() + (
33
+ (-pre).sigmoid() * (F.logsigmoid(-pre) - F.logsigmoid(-post))
34
+ ).mean()
35
+ else: # We have sequences of predictions; masking needed
36
+ if pre_.shape[-1] > 1:
37
+ assert mask is not None
38
+ mask_ = mask.view(pre_.shape[0])
39
+ kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1)
40
+ return (kl * mask_).sum() / mask_.sum()
41
+
42
+ raise NotImplementedError
43
+
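A quick sanity check for `kl_loc_loss` (toy tensors): the masked KL between identical sequence logits is zero.

    import torch
    from losses import kl_loc_loss

    pre = torch.randn(2, 5, 10)    # (batch, seq, vocab) pre-edit logits
    post = pre.clone()             # identical post-edit logits
    mask = torch.ones(2, 5)
    assert kl_loc_loss(pre, post, mask=mask).abs() < 1e-6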
44
+
45
+ def binary_log_probs(pred, targ, should_reduce=True):
46
+ assert targ.max() <= 1
47
+ assert targ.min() >= 0
48
+ neg_mask = torch.ones_like(pred)
49
+ neg_mask[targ == 0] *= -1
50
+ pred = pred * neg_mask
51
+ log_probs = F.logsigmoid(pred)
52
+ acc = (log_probs.exp() > 0.5).float()
53
+ if should_reduce:
54
+ acc = acc.mean()
55
+ return {
56
+ "acc": acc,
57
+ "log_prob": log_probs.mean(),
58
+ "prob": log_probs.exp().mean(),
59
+ "nll": -log_probs.mean(),
60
+ "n_tokens": log_probs.shape[0]
61
+ }
62
+
63
+
64
+ def multiclass_log_probs(
65
+ pred,
66
+ raw_targets,
67
+ shift=True,
68
+ eps=torch.finfo(torch.float32).eps,
69
+ should_reduce=True,
70
+ **kwargs,
71
+ ):
72
+ NULL_TOKEN = 0 # a placeholder used for masked target locations
73
+
74
+ pred = pred.clone()
75
+ mask, targ = mask_hf_labels(raw_targets)
76
+ if shift and pred.dim() == 3: # Dealing with sequences
77
+ pred = pred[:, :-1] # Remove last prediction in sequence
78
+ targ = targ[:, 1:]  # Shift to align predictions and targets
+ mask = mask[:, 1:]  # Keep the mask aligned with the shifted targets
79
+
80
+ unmasked_log_probs = gather_log_probs(pred, targ)
81
+
82
+ pred_ids = pred.argmax(-1).masked_fill(~mask, NULL_TOKEN)
83
+ correct = pred_ids == targ
84
+ if pred.dim() == 3:
85
+ correct = (pred_ids == targ).all(-1) # We want to get the whole sequence right
86
+ acc = correct.float()
87
+ if should_reduce:
88
+ acc = acc.mean()
89
+
90
+ if "inner_sent" in kwargs:
91
+ # Only use outer samples with the same sentiment as the inner sample
92
+ same_sent_mask = torch.tensor([i == o for i, o in zip(kwargs["inner_sent"], kwargs["outer_sent"])], device=pred.device)
93
+ good_mask = mask * same_sent_mask.unsqueeze(-1)
94
+ bad_mask = mask * (~same_sent_mask.unsqueeze(-1))
95
+
96
+ good_log_prob = masked_mean(unmasked_log_probs, good_mask)
97
+ bad_log_prob = masked_mean((1 - unmasked_log_probs.exp() + eps).log(), bad_mask)
98
+
99
+ n_tokens = good_mask.float().sum()
100
+ avg_log_prob = good_log_prob
101
+
102
+ if kwargs["unlikelihood"]:
103
+ nll = -good_log_prob - bad_log_prob
104
+ else:
105
+ nll = -good_log_prob
106
+ else:
107
+ n_tokens = mask.float().sum()
108
+ avg_log_prob = (unmasked_log_probs * mask.float()).sum() / n_tokens
109
+ nll = -avg_log_prob
110
+
111
+ info_dict = {
112
+ "acc": acc,
113
+ "log_prob": avg_log_prob,
114
+ "prob": avg_log_prob.exp(),
115
+ "n_tokens": n_tokens,
116
+ "nll": nll
117
+ }
118
+
119
+ if "inner_sent" in kwargs:
120
+ info_dict.update(es_sentiment(kwargs["pre_edit_logits"],
121
+ kwargs["post_edit_logits"],
122
+ raw_targets,
123
+ same_sent_mask))
124
+
125
+ return info_dict
126
+
127
+
128
+ def masked_log_probs(pred, targ, shift=True, **kwargs):
129
+ pred = pred.to(torch.float32)
130
+
131
+ if not (pred.dim() == 2 or pred.dim() == 3):
132
+ raise RuntimeError(f"Expected pred to have 2 or 3 dimensions, got {pred.shape}")
133
+
134
+ if pred.shape[-1] == 1:
135
+ should_reduce = True
136
+ if "should_reduce" in kwargs:
137
+ should_reduce = kwargs["should_reduce"]
138
+ return binary_log_probs(pred, targ, should_reduce=should_reduce)
139
+ else:
140
+ return multiclass_log_probs(pred, targ, shift=shift, **kwargs)
141
+
142
+
143
+ def test_masked_log_probs():
144
+ print()
145
+ N = 10000
146
+ pred = torch.randn(10, 15, N)
147
+ targ = torch.randint(0, N, (10, 15))
148
+ true_pred = pred.clone()
149
+ true_pred.scatter_(2, targ.unsqueeze(-1), 5)
150
+ true_pred = true_pred.roll(-1, 1)
151
+
152
+ half_pred = true_pred.clone()
153
+ mask = torch.arange(10) % 2 == 0
154
+ half_pred[mask] = pred[mask]
155
+
156
+ pred_ = pred.clone()
157
+ true_pred_ = true_pred.clone()
158
+ half_pred_ = half_pred.clone()
159
+ targ_ = targ.clone()
160
+
161
+ print(masked_log_probs(pred, targ, return_acc=True))
162
+ print(masked_log_probs(true_pred, targ, return_acc=True))
163
+ print(masked_log_probs(half_pred, targ, return_acc=True))
164
+
165
+ assert (pred == pred_).all()
166
+ assert (targ == targ_).all()
167
+ assert (half_pred == half_pred_).all()
168
+ assert (true_pred == true_pred_).all()
169
+
170
+ import pdb; pdb.set_trace()
171
+
172
+ pred = torch.randn(1000, 15, 1)
173
+ targ = torch.randint(0, 2, (1000, 15))
174
+
175
+ print(masked_log_probs(pred, targ, return_acc=True))
176
+
177
+
178
+ if __name__ == "__main__":
179
+ torch.manual_seed(0)
180
+
181
+ test_masked_log_probs()
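The intended call pattern for `masked_log_probs` on sequence models, using the Hugging Face convention of `-100` for ignored label positions (toy vocabulary and shapes):

    import torch
    from losses import masked_log_probs

    torch.manual_seed(0)
    pred = torch.randn(2, 6, 10)         # (batch, seq, vocab) logits
    targ = torch.randint(0, 10, (2, 6))
    targ[:, :2] = -100                   # ignored positions, as in HF labels

    stats = masked_log_probs(pred, targ, shift=True)
    print(stats["acc"], stats["nll"], stats["n_tokens"])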
metrics.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ from utils import gather_log_probs, mask_hf_labels, masked_mean
3
+
4
+
5
+ def es_sentiment(pre_logits, post_logits, raw_targets, same_sent_mask, NULL_TOKEN=0):
6
+ with torch.no_grad():
7
+ mask, targ = mask_hf_labels(raw_targets)
8
+ pos_mask = same_sent_mask.unsqueeze(-1) * mask
9
+ neg_mask = (~same_sent_mask).unsqueeze(-1) * mask
10
+
11
+ # Compute log likelihoods of pos/neg samples
12
+ pre_edit_token_log_probs = gather_log_probs(pre_logits, targ)
13
+ post_edit_token_log_probs = gather_log_probs(post_logits, targ)
14
+
15
+ mean_pos_pre = masked_mean(pre_edit_token_log_probs, pos_mask)
16
+ mean_pos_post = masked_mean(post_edit_token_log_probs, pos_mask)
17
+ mean_neg_post = masked_mean(post_edit_token_log_probs, neg_mask)
18
+
19
+ z_sent = (mean_pos_post - mean_neg_post).sigmoid()
20
+ z_topic_raw = (mean_pos_post - mean_pos_pre).exp()
21
+ z_topic = min(1, z_topic_raw)
22
+
23
+ es_sent = z_sent * z_topic
24
+
25
+ return {
26
+ "acc_sent": es_sent,
27
+ "z_sent": z_sent,
28
+ "z_topic": z_topic,
29
+ "z_topic_raw": z_topic_raw,
30
+ "correct_probs": mean_pos_post,
31
+ "wrong_probs": mean_neg_post,
32
+ }
33
+
34
+
35
+ # DEPRECATED
36
+ def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
37
+ assert False, "No longer used"
38
+ # content_score = post_edit_probs[pos_mask].prod() ** (1/pos_mask.sum()) / (pre_edit_probs[pos_mask]. + eps)
39
+ post_pos_avg = post_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
40
+ pre_pos_avg = pre_edit_probs[pos_mask].prod() ** (1 / pos_mask.sum())
41
+ content_score = post_pos_avg / (pre_pos_avg + eps)
42
+ z_content = min(1., content_score)
43
+
44
+ # compute z_sent through a weighting objective
45
+ # normalized_probs = post_edit_probs / (post_edit_probs.sum() + eps)
46
+ # balancing_factor = 0.5 * ((~pos_mask).float().sum() / pos_mask.float().sum() + 1)
47
+ # z_sent_weight = balancing_factor * normalized_probs.dot(pos_mask.float())
48
+ post_neg_avg = post_edit_probs[~pos_mask].prod() ** (1 / (~pos_mask).sum())
49
+ neg_over_pos = post_neg_avg / (eps + post_pos_avg)
50
+ z_sent_weight = 1 / (1 + neg_over_pos)
51
+
52
+ # compute z_sent through a ranking objective
53
+ batch_mask = pos_mask.view(-1, batch_size).long()
54
+ sort_idxs = post_edit_probs.view(-1, batch_size).sort(-1, descending=True).indices
55
+ ranked_mask = batch_mask.gather(1, sort_idxs)
56
+ true_mask = batch_mask.sort(-1, descending=True).values
57
+ z_sent_rank = (ranked_mask == true_mask).float().mean()
58
+
59
+ # compute the final success scores
60
+ weight_success = (z_content * z_sent_weight) ** 0.5
61
+ rank_success = (z_content * z_sent_rank) ** 0.5
62
+
63
+ correct_probs = post_edit_probs[pos_mask].mean()
64
+ wrong_probs = post_edit_probs[~pos_mask].mean()
65
+
66
+ return {
67
+ "acc_weight": weight_success,
68
+ "acc_rank": rank_success,
69
+ "rank_score": z_sent_rank,
70
+ "weight_score": z_sent_weight,
71
+ "content_score": content_score,
72
+ "post_edit_probs": post_edit_probs.sum(),
73
+ "pre_edit_probs": pre_edit_probs.sum(),
74
+ "correct_probs": correct_probs,
75
+ "wrong_probs": wrong_probs
76
+ }
77
+
78
+
79
+ # def sent_retain(pre_logits, post_logits, sent_mask, batch_size=20, eps=torch.finfo(torch.float32).eps):
80
+ # pre_log_probs = pre_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
81
+ # post_log_probs = post_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
82
+
83
+ # pre_batch = pre_probs.view(-1, batch_size)
84
+ # post_batch = post_probs.view(-1, batch_size)
85
+ # mask_batch = sent_mask.view(-1, batch_size)
86
+
87
+ # stats = []
88
+ # for pre, post, mask in zip(pre_batch, post_batch, mask_batch):
89
+ # avg_pre = pre.prod() ** (1 / pre.numel())
90
+ # avg_post = post.prod() ** (1 / post.numel())
91
+ # z_avg = min(avg_pre / avg_post, avg_post / avg_pre)
92
+
93
+ # post_neg_avg = post[~mask].prod() ** (1 / (~mask).sum())
94
+ # post_pos_avg = post[mask].prod() ** (1 / mask.sum())
95
+
96
+ # pre_neg_avg = pre[~mask].prod() ** (1 / (~mask).sum())
97
+ # pre_pos_avg = pre[mask].prod() ** (1 / mask.sum())
98
+
99
+ # post_neg_over_pos = post_neg_avg / (eps + post_pos_avg)
100
+ # pre_neg_over_pos = pre_neg_avg / (eps + pre_pos_avg)
101
+ # z_post = 1 / (1 + post_neg_over_pos)
102
+ # z_pre = 1 / (1 + pre_neg_over_pos)
103
+
104
+ # z_sent = min(z_post / z_pre, z_pre / z_post)
105
+
106
+ # stats.append((z_avg * z_sent) ** 0.5)
107
+
108
+ # return sum(stats) / len(stats)
109
+
110
+
111
+ # For zsRE and F-NLI
112
+ def retain_rate(pre_logits, post_logits, mask=None):
113
+ if pre_logits.shape[-1] == 1:
114
+ pre_logits = pre_logits.squeeze(-1)
115
+ if post_logits.shape[-1] == 1:
116
+ post_logits = post_logits.squeeze(-1)
117
+
118
+ assert pre_logits.shape == post_logits.shape
119
+ assert pre_logits.shape[0] == mask.shape[0]
120
+
121
+ if pre_logits.dim() == 1:
122
+ # binary classification
123
+ pre_preds = pre_logits > 0
124
+ post_preds = post_logits > 0
125
+ retain = (pre_preds == post_preds).float().mean()
126
+ elif pre_logits.dim() == 3:
127
+ # sequence modeling
128
+ pre_preds = pre_logits.argmax(-1)
129
+ post_preds = post_logits.argmax(-1)
130
+ match = (pre_preds == post_preds) * mask
131
+ retain = (match.sum(-1) == mask.sum(-1)).float().mean()
132
+ else:
133
+ raise NotImplementedError
134
+
135
+ return retain.item()
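`retain_rate` compares pre- and post-edit predictions; for sequences, an example only counts as retained if every masked position keeps the same argmax. A toy call (shapes are illustrative):

    import torch
    from metrics import retain_rate

    pre = torch.randn(4, 7, 50)       # (batch, seq, vocab) pre-edit logits
    post = pre.clone()
    post[0] += torch.randn(7, 50)     # perturb one example
    mask = torch.ones(4, 7)

    print(retain_rate(pre, post, mask=mask))   # fraction of examples with unchanged masked argmax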
models.py ADDED
@@ -0,0 +1,196 @@
1
+ import transformers
2
+ import torch
3
+ import torch.nn as nn
4
+ import re
5
+ import logging
6
+ from nn import FixableDropout
7
+ from utils import scr
8
+
9
+
10
+ LOG = logging.getLogger(__name__)
11
+
12
+
13
+ class CastModule(nn.Module):
14
+ def __init__(self, module: nn.Module, in_cast: torch.dtype = torch.float32, out_cast: torch.dtype = None):
15
+ super().__init__()
16
+
17
+ self.underlying = module
18
+ self.in_cast = in_cast
19
+ self.out_cast = out_cast
20
+
21
+ def cast(self, obj, dtype):
22
+ if dtype is None:
23
+ return obj
24
+
25
+ if isinstance(obj, torch.Tensor):
26
+ return obj.to(dtype)
27
+ else:
28
+ return obj
29
+
30
+ def forward(self, *args, **kwargs):
31
+ args = tuple(self.cast(a, self.in_cast) for a in args)
32
+ kwargs = {k: self.cast(v, self.in_cast) for k, v in kwargs.items()}
33
+ outputs = self.underlying(*args, **kwargs)
34
+ if isinstance(outputs, torch.Tensor):
35
+ outputs = self.cast(outputs, self.out_cast)
36
+ elif isinstance(outputs, tuple):
37
+ outputs = tuple(self.cast(o, self.out_cast) for o in outputs)
38
+ else:
39
+ raise RuntimeError(f"Not sure how to cast type {type(outputs)}")
40
+ return outputs
41
+
42
+ def extra_repr(self):
43
+ return f"in_cast: {self.in_cast}\nout_cast: {self.out_cast}"
44
+
45
+
46
+ class BertClassifier(torch.nn.Module):
47
+ def __init__(self, model_name, hidden_dim=768):
48
+ super().__init__()
49
+ if model_name.startswith("bert"):
50
+ self.model = transformers.BertModel.from_pretrained(model_name, cache_dir=scr())
51
+ else:
52
+ self.model = transformers.AutoModel.from_pretrained(model_name, cache_dir=scr())
53
+ self.classifier = torch.nn.Linear(hidden_dim, 1)
54
+
55
+ @property
56
+ def config(self):
57
+ return self.model.config
58
+
59
+ def forward(self, *args, **kwargs):
60
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k != "labels"}
61
+ model_output = self.model(*args, **filtered_kwargs)
62
+ if "pooler_output" in model_output.keys():
63
+ pred = self.classifier(model_output.pooler_output)
64
+ else:
65
+ pred = self.classifier(model_output.last_hidden_state[:, 0])
66
+
67
+ if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]:
68
+ last_hidden_state = model_output.last_hidden_state
69
+ return pred, last_hidden_state
70
+ else:
71
+ return pred
72
+
73
+
74
+ def replace_dropout(model):
75
+ for m in model.modules():
76
+ for n, c in m.named_children():
77
+ if isinstance(c, nn.Dropout):
78
+ setattr(m, n, FixableDropout(c.p))
79
+
80
+ def resample(m, seed=None):
81
+ for c in m.children():
82
+ if hasattr(c, "resample"):
83
+ c.resample(seed)
84
+ else:
85
+ resample(c, seed)
86
+
87
+ model.resample_dropout = resample.__get__(model)
88
+
89
+
90
+ def get_model(config):
91
+ if config.model.class_name == "BertClassifier":
92
+ model = BertClassifier(config.model.name)
93
+ else:
94
+ ModelClass = getattr(transformers, config.model.class_name)
95
+ LOG.info(f"Loading model class {ModelClass} with name {config.model.name} from cache dir {scr()}")
96
+ model = ModelClass.from_pretrained(config.model.name, cache_dir=scr())
97
+
98
+ if config.model.pt is not None:
99
+ LOG.info(f"Loading model initialization from {config.model.pt}")
100
+ state_dict = torch.load(config.model.pt, map_location="cpu")
101
+
102
+ try:
103
+ model.load_state_dict(state_dict)
104
+ except RuntimeError:
105
+ LOG.info("Default load failed; stripping prefix and trying again.")
106
+ state_dict = {re.sub("^model.", "", k): v for k, v in state_dict.items()}
107
+
108
+ model.load_state_dict(state_dict)
109
+
110
+ LOG.info("Loaded model initialization")
111
+
112
+ if config.dropout is not None:
113
+ n_reset = 0
114
+ for m in model.modules():
115
+ if isinstance(m, nn.Dropout):
116
+ m.p = config.dropout
117
+ n_reset += 1
118
+
119
+ if hasattr(m, "dropout"):  # Required for BART, which uses F.dropout
120
+ if isinstance(m.dropout, float):
121
+ m.dropout = config.dropout
122
+ n_reset += 1
123
+
124
+ if hasattr(m, "activation_dropout"):  # Required for BART, which uses F.dropout
125
+ if isinstance(m.activation_dropout, float):
126
+ m.activation_dropout = config.dropout
127
+ n_reset += 1
128
+
129
+ LOG.info(f"Set {n_reset} dropout modules to p={config.dropout}")
130
+
131
+ param_names = [n for n, _ in model.named_parameters()]
132
+ bad_inner_params = [p for p in config.model.inner_params if p not in param_names]
133
+ if len(bad_inner_params) != 0:
134
+ raise ValueError(f"Params {bad_inner_params} do not exist in model of type {type(model)}.")
135
+
136
+ if config.no_grad_layers is not None:
137
+ if config.half:
138
+ model.bfloat16()
139
+
140
+ def upcast(mod):
141
+ modlist = None
142
+ for child in mod.children():
143
+ if isinstance(child, nn.ModuleList):
144
+ assert modlist is None, f"Found multiple modlists for {mod}"
145
+ modlist = child
146
+ if modlist is None:
147
+ raise RuntimeError("Couldn't find a ModuleList child")
148
+
149
+ LOG.info(f"Setting {len(modlist) - config.no_grad_layers} modules to full precision, with autocasting")
150
+ modlist[config.no_grad_layers:].to(torch.float32)
151
+ modlist[config.no_grad_layers] = CastModule(modlist[config.no_grad_layers])
152
+ modlist[-1] = CastModule(modlist[-1], in_cast=torch.float32, out_cast=torch.bfloat16)
153
+
154
+ parents = []
155
+ if hasattr(model, "transformer"):
156
+ parents.append(model.transformer)
157
+ if hasattr(model, "encoder"):
158
+ parents.append(model.encoder)
159
+ if hasattr(model, "decoder"):
160
+ parents.append(model.decoder)
161
+ if hasattr(model, "model"):
162
+ parents.extend([model.model.encoder, model.model.decoder])
163
+
164
+ for t in parents:
165
+ t.no_grad_layers = config.no_grad_layers
166
+ if config.half and config.alg != "rep":
167
+ upcast(t)
168
+
169
+ if config.half and config.alg != "rep":
170
+ idxs = []
171
+ for p in config.model.inner_params:
172
+ for comp in p.split('.'):
173
+ if comp.isdigit():
174
+ idxs.append(int(comp))
175
+ max_idx, min_idx = str(max(idxs)), str(config.no_grad_layers)
176
+ for pidx, p in enumerate(config.model.inner_params):
177
+ comps = p.split('.')
178
+ if max_idx in comps or min_idx in comps:
179
+ index = comps.index(max_idx) if max_idx in comps else comps.index(min_idx)
180
+ comps.insert(index + 1, 'underlying')
181
+ new_p = '.'.join(comps)
182
+ LOG.info(f"Replacing config.model.inner_params[{pidx}] '{p}' -> '{new_p}'")
183
+ config.model.inner_params[pidx] = new_p
184
+
185
+ return model
186
+
187
+
188
+ def get_tokenizer(config):
189
+ tok_name = config.model.tokenizer_name if config.model.tokenizer_name is not None else config.model.name
190
+ return getattr(transformers, config.model.tokenizer_class).from_pretrained(tok_name, cache_dir=scr())
191
+
192
+
193
+ if __name__ == '__main__':
194
+ m = BertClassifier("bert-base-uncased")
195
+ m(torch.arange(5)[None, :])
196
+ import pdb; pdb.set_trace()
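`get_model`/`get_tokenizer` read a handful of config fields; a minimal namespace-style sketch of those fields (the GPT-2 names and layer choice are just examples, and note that `scr()` assumes this repo's cluster scratch layout):

    import types
    from models import get_model, get_tokenizer

    config = types.SimpleNamespace()
    config.model = types.SimpleNamespace(
        class_name="GPT2LMHeadModel",
        name="gpt2",
        pt=None,                         # optional path to a fine-tuned state dict
        tokenizer_class="GPT2Tokenizer",
        tokenizer_name=None,             # falls back to config.model.name
        inner_params=["transformer.h.11.mlp.c_fc.weight"],
    )
    config.dropout = None
    config.no_grad_layers = None

    model = get_model(config)
    tokenizer = get_tokenizer(config)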
nn.py ADDED
@@ -0,0 +1,362 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import logging
5
+ import time
6
+
7
+ from utils import factorization
8
+
9
+ LOG = logging.getLogger(__name__)
10
+
11
+
12
+ class FixableDropout(nn.Module):
13
+ def __init__(self, p: float):
14
+ super().__init__()
15
+
16
+ self.p = p
17
+ self.mask_cache = {}
18
+ self.seed = 0
19
+
20
+ def resample(self, seed=None):
21
+ if seed is None:
22
+ seed = int(time.time() * 1e6)
23
+ self.mask_cache = {}
24
+ self.seed = seed
25
+
26
+ def forward(self, x):
27
+ if self.training:
28
+ if x.shape not in self.mask_cache:
29
+ generator = torch.Generator(x.device).manual_seed(self.seed)
30
+ self.mask_cache[x.shape] = torch.bernoulli(
31
+ torch.full_like(x, 1 - self.p), generator=generator
32
+ ).bool()
33
+ self.should_resample = False
34
+
35
+ x = (self.mask_cache[x.shape] * x) / (1 - self.p)
36
+
37
+ return x
38
+
39
+ def extra_repr(self) -> str:
40
+ return f"p={self.p}"
41
+
42
+
43
+ class ActMLP(nn.Module):
44
+ def __init__(self, hidden_dim, n_hidden):
45
+ super().__init__()
46
+
47
+ self.mlp = MLP(1, 1, hidden_dim, n_hidden, init="id")
48
+
49
+ def forward(self, x):
50
+ return self.mlp(x.view(-1, 1)).view(x.shape)
51
+
52
+
53
+ class LightIDMLP(nn.Module):
54
+ def __init__(
55
+ self,
56
+ indim: int,
57
+ outdim: int,
58
+ hidden_dim: int,
59
+ n_hidden: int,
60
+ init: str = None,
61
+ act: str = None,
62
+ rank: int = None,
63
+ ):
64
+ super().__init__()
65
+ LOG.info(f"Building LightIDMLP {[indim] + [rank] + [indim]}")
66
+ self.layer1 = nn.Linear(indim, rank)
67
+ self.layer2 = nn.Linear(rank, indim)
68
+ self.layer2.weight.data[:] = 0
69
+ self.layer2.bias = None
70
+
71
+ def forward(self, x):
72
+ h = self.layer1(x).relu()
73
+ return x + self.layer2(h)
74
+
75
+
76
+ class IDMLP(nn.Module):
77
+ def __init__(
78
+ self,
79
+ indim: int,
80
+ outdim: int,
81
+ hidden_dim: int,
82
+ n_hidden: int,
83
+ init: str = None,
84
+ act: str = None,
85
+ rank: int = None,
86
+ n_modes: int = None
87
+ ):
88
+ super().__init__()
89
+ LOG.info(f"Building IDMLP ({init}) {[indim] * (n_hidden + 2)}")
90
+ self.layers = nn.ModuleList(
91
+ [
92
+ LRLinear(indim, indim, rank=rank, relu=idx < n_hidden, init=init, n_modes=n_modes)
93
+ for idx in range(n_hidden + 1)
94
+ ]
95
+ )
96
+
97
+ def forward(self, x, mode=None):
98
+ for layer in self.layers:
99
+ x = layer(x, mode=mode)
100
+
101
+ return x
102
+
103
+
104
+ class LatentIDMLP(nn.Module):
105
+ def __init__(
106
+ self,
107
+ indim: int,
108
+ outdim: int,
109
+ hidden_dim: int,
110
+ n_hidden: int,
111
+ init: str = None,
112
+ act: str = None,
113
+ rank: int = None,
114
+ ):
115
+ super().__init__()
116
+ LOG.info(f"Building Latent IDMLP ({init}) {[indim] * (n_hidden + 2)}")
117
+
118
+ self.layers = nn.ModuleList()
119
+ self.layers.append(nn.Linear(indim, rank))
120
+ for _ in range(n_hidden - 1):
121
+ self.layers.append(nn.Linear(rank, rank))
122
+ self.layers.append(nn.Linear(rank, outdim))
123
+
124
+ for layer in self.layers[:-1]:
125
+ nn.init.xavier_normal_(layer.weight.data)
126
+
127
+ if init == "id":
128
+ self.layers[-1].weight.data.zero_()
129
+ self.layers[-1].bias.data.zero_()
130
+
131
+ self.init = init
132
+
133
+ def forward(self, x):
134
+ out = x
135
+ for layer in self.layers[:-1]:
136
+ out = layer(out).relu()
137
+
138
+ out = self.layers[-1](out)
139
+ if self.init == "id":
140
+ return out + x
141
+ else:
142
+ return out
143
+
144
+
145
+ class KLinear(nn.Module):
146
+ def __init__(self, inf, outf, pfrac=0.05, symmetric=True, zero_init: bool = True):
147
+ super().__init__()
148
+
149
+ self.inf = inf
150
+
151
+ in_fact = factorization(inf)
152
+ out_fact = factorization(outf)
153
+
154
+ total_params = 0
155
+ self.a, self.b = nn.ParameterList(), nn.ParameterList()
156
+ for (i1, i2), (o1, o2) in zip(reversed(in_fact), reversed(out_fact)):
157
+ new_params = (o1 * i1 + o2 * i2) * (2 if symmetric else 1)
158
+ if (total_params + new_params) / (inf * outf) > pfrac and len(self.a) > 0:
159
+ break
160
+ total_params += new_params
161
+
162
+ self.a.append(nn.Parameter(torch.empty(o1, i1)))
163
+ if symmetric:
164
+ self.a.append(nn.Parameter(torch.empty(o2, i2)))
165
+
166
+ self.b.append(nn.Parameter(torch.empty(o2, i2)))
167
+ if symmetric:
168
+ self.b.append(nn.Parameter(torch.empty(o1, i1)))
169
+
170
+ assert self.a[-1].kron(self.b[-1]).shape == (outf, inf)
171
+
172
+ for factor in self.a:
173
+ nn.init.kaiming_normal_(factor.data)
174
+ for factor in self.b:
175
+ if zero_init:
176
+ factor.data.zero_()
177
+ else:
178
+ nn.init.kaiming_normal_(factor.data)
179
+
180
+ print(f"Created ({symmetric}) k-layer using {total_params/(outf*inf):.3f} params, {len(self.a)} comps")
181
+ self.bias = nn.Parameter(torch.zeros(outf))
182
+
183
+ def forward(self, x):
184
+ assert x.shape[-1] == self.inf, f"Expected input with {self.inf} dimensions, got {x.shape}"
185
+ w = sum([a.kron(b) for a, b in zip(self.a, self.b)]) / (2 * len(self.a) ** 0.5)
186
+ y = w @ x.T
187
+ if self.bias is not None:
188
+ y = y + self.bias
189
+ return y
190
+
191
+
192
+ class LRLinear(nn.Module):
193
+ def __init__(self, inf, outf, rank: int = None, relu=False, init="id", n_modes=None):
194
+ super().__init__()
195
+
196
+ mid_dim = min(rank, inf)
197
+ if init == "id":
198
+ self.u = nn.Parameter(torch.zeros(outf, mid_dim))
199
+ self.v = nn.Parameter(torch.randn(mid_dim, inf))
200
+ elif init == "xavier":
201
+ self.u = nn.Parameter(torch.empty(outf, mid_dim))
202
+ self.v = nn.Parameter(torch.empty(mid_dim, inf))
203
+ nn.init.xavier_uniform_(self.u.data, gain=nn.init.calculate_gain("relu"))
204
+ nn.init.xavier_uniform_(self.v.data, gain=1.0)
205
+ else:
206
+ raise ValueError(f"Unrecognized initialization {init}")
207
+
208
+ if n_modes is not None:
209
+ self.mode_shift = nn.Embedding(n_modes, outf)
210
+ self.mode_shift.weight.data.zero_()
211
+ self.mode_scale = nn.Embedding(n_modes, outf)
212
+ self.mode_scale.weight.data.fill_(1)
213
+
214
+ self.n_modes = n_modes
215
+ self.bias = nn.Parameter(torch.zeros(outf))
216
+ self.inf = inf
217
+ self.init = init
218
+
219
+ def forward(self, x, mode=None):
220
+ if mode is not None:
221
+ assert self.n_modes is not None, "Linear got a mode but wasn't initialized for it"
222
+ assert mode < self.n_modes, f"Input mode {mode} outside of range {self.n_modes}"
223
+ assert x.shape[-1] == self.inf, f"Input wrong dim ({x.shape}, {self.inf})"
224
+
225
+ pre_act = (self.u @ (self.v @ x.T)).T
226
+ if self.bias is not None:
227
+ pre_act += self.bias
228
+
229
+ if mode is not None:
230
+ if not isinstance(mode, torch.Tensor):
231
+ mode = torch.tensor(mode).to(x.device)
232
+ scale, shift = self.mode_scale(mode), self.mode_shift(mode)
233
+ pre_act = pre_act * scale + shift
234
+
235
+ # need clamp instead of relu so gradient at 0 isn't 0
236
+ acts = pre_act.clamp(min=0)
237
+ if self.init == "id":
238
+ return acts + x
239
+ else:
240
+ return acts
241
+
242
+
243
+ class MLP(nn.Module):
244
+ def __init__(
245
+ self,
246
+ indim: int,
247
+ outdim: int,
248
+ hidden_dim: int,
249
+ n_hidden: int,
250
+ init: str = "xavier_uniform",
251
+ act: str = "relu",
252
+ rank: int = None,
253
+ ):
254
+ super().__init__()
255
+
256
+ self.init = init
257
+
258
+ if act == "relu":
259
+ self.act = nn.ReLU()
260
+ elif act == "learned":
261
+ self.act = ActMLP(10, 1)
262
+ else:
263
+ raise ValueError(f"Unrecognized activation function '{act}'")
264
+
265
+ if hidden_dim is None:
266
+ hidden_dim = outdim * 2
267
+
268
+ if init.startswith("id") and outdim != indim:
269
+ LOG.info(f"Overwriting outdim ({outdim}) to be indim ({indim})")
270
+ outdim = indim
271
+
272
+ if init == "id":
273
+ old_hidden_dim = hidden_dim
274
+ if hidden_dim < indim * 2:
275
+ hidden_dim = indim * 2
276
+
277
+ if hidden_dim % indim != 0:
278
+ hidden_dim += hidden_dim % indim
279
+
280
+ if old_hidden_dim != hidden_dim:
281
+ LOG.info(
282
+ f"Overwriting hidden dim ({old_hidden_dim}) to be {hidden_dim}"
283
+ )
284
+
285
+ if init == "id_alpha":
286
+ self.alpha = nn.Parameter(torch.zeros(1, outdim))
287
+
288
+ dims = [indim] + [hidden_dim] * n_hidden + [outdim]
289
+ LOG.info(f"Building ({init}) MLP: {dims} (rank {rank})")
290
+
291
+ layers = []
292
+ for idx, (ind, outd) in enumerate(zip(dims[:-1], dims[1:])):
293
+ if rank is None:
294
+ layers.append(nn.Linear(ind, outd))
295
+ else:
296
+ layers.append(LRLinear(ind, outd, rank=rank))
297
+ if idx < n_hidden:
298
+ layers.append(self.act)
299
+
300
+ if rank is None:
301
+ if init == "id":
302
+ if n_hidden > 0:
303
+ layers[0].weight.data = torch.eye(indim).repeat(
304
+ hidden_dim // indim, 1
305
+ )
306
+ layers[0].weight.data[hidden_dim // 2:] *= -1
307
+ layers[-1].weight.data = torch.eye(outdim).repeat(
308
+ 1, hidden_dim // outdim
309
+ )
310
+ layers[-1].weight.data[:, hidden_dim // 2:] *= -1
311
+ layers[-1].weight.data /= (hidden_dim // indim) / 2.0
312
+
313
+ for layer in layers:
314
+ if isinstance(layer, nn.Linear):
315
+ if init == "ortho":
316
+ nn.init.orthogonal_(layer.weight)
317
+ elif init == "id":
318
+ if layer.weight.shape[0] == layer.weight.shape[1]:
319
+ layer.weight.data = torch.eye(hidden_dim)
320
+ else:
321
+ gain = 3 ** 0.5 if (layer is layers[-1]) else 1.0
322
+ nn.init.xavier_uniform_(layer.weight, gain=gain)
323
+
324
+ layer.bias.data[:] = 0
325
+
326
+ layers[-1].bias = None
327
+ self.mlp = nn.Sequential(*layers)
328
+
329
+ def forward(self, x):
330
+ if self.init == "id_alpha":
331
+ return x + self.alpha * self.mlp(x)
332
+ else:
333
+ return self.mlp(x)
334
+
335
+
336
+ if __name__ == "__main__":
337
+ logging.basicConfig(
338
+ format="%(asctime)s - %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
339
+ level=logging.INFO,
340
+ )
341
+ m0 = MLP(1000, 1000, 1500, 3)
342
+ m1 = MLP(1000, 1000, 1500, 3, init="id")
343
+ m2 = MLP(1000, 1000, 1500, 3, init="id_alpha")
344
+ m3 = MLP(1000, 1000, 1500, 3, init="ortho", act="learned")
345
+
346
+ x = 0.01 * torch.randn(999, 1000)
347
+
348
+ y0 = m0(x)
349
+ y1 = m1(x)
350
+ y2 = m2(x)
351
+ y3 = m3(x)
352
+
353
+ print("y0", (y0 - x).abs().max())
354
+ print("y1", (y1 - x).abs().max())
355
+ print("y2", (y2 - x).abs().max())
356
+ print("y3", (y3 - x).abs().max())
357
+
358
+ assert not torch.allclose(y0, x)
359
+ assert torch.allclose(y1, x)
360
+ assert torch.allclose(y2, x)
361
+ assert not torch.allclose(y3, x)
362
+ import pdb; pdb.set_trace() # fmt: skip
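The `init="id"` low-rank layers above start out as exact identity maps (zero-initialized `u` plus the residual connection); a quick check with hypothetical dimensions, assuming the repo root is importable:

    import torch
    from nn import IDMLP

    net = IDMLP(indim=64, outdim=64, hidden_dim=None, n_hidden=1, init="id", rank=4)
    x = torch.randn(8, 64)
    assert torch.allclose(net(x), x)   # each LRLinear contributes zero at initialization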
requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ allennlp
2
+ git+https://github.com/eric-mitchell/higher@master # For in-place functional models
3
+ pandas
4
+ streamlit
5
+ torch
6
+ transformers
utils.py ADDED
@@ -0,0 +1,441 @@
1
+ import datetime
2
+ import typing
3
+ import numpy as np
4
+ import struct
5
+ import os
6
+ import getpass
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ from collections import defaultdict
11
+ import math
12
+
13
+
14
+ LOG = logging.getLogger(__name__)
15
+
16
+ def masked_mean(values, mask):
17
+ assert mask.dtype == torch.bool
18
+ assert values.shape == mask.shape
19
+ return (values * mask.float()).sum() / mask.sum().float()
20
+
21
+
22
+ def mask_hf_labels(labels, null_token=0):
23
+ valid_mask = labels != -100
24
+ valid_labels = labels.masked_fill(~valid_mask, null_token)
25
+ return valid_mask, valid_labels
26
+
27
+
28
+ def gather_log_probs(logits, labels):
29
+ assert labels.dim() == logits.dim() - 1
30
+ assert labels.shape == logits.shape[:-1]
31
+ return logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
32
+
33
+
34
+ def off_diagonal(mat):
35
+ assert mat.dim() == 2
36
+ # assert mat.shape[0] == mat.shape[1]
37
+
38
+ mask = ~torch.eye(max(mat.shape), dtype=torch.bool)
39
+ mask = mask[:mat.shape[0], :mat.shape[1]]
40
+ off_d = mat[mask]
41
+
42
+ assert off_d.numel() == mat.shape[0] * mat.shape[1] - min(mat.shape)
43
+
44
+ return off_d
45
+
46
+
47
+ def set_dropout(model, p):
48
+ if p is not None:
49
+ n_reset = 0
50
+ for m in model.modules():
51
+ if isinstance(m, nn.Dropout):
52
+ m.p = p
53
+ n_reset += 1
54
+
55
+ if hasattr(m, "dropout"):  # Required for BART, which uses F.dropout
56
+ if isinstance(m.dropout, float):
57
+ m.dropout = p
58
+ n_reset += 1
59
+
60
+ if hasattr(m, "activation_dropout"):  # Required for BART, which uses F.dropout
61
+ if isinstance(m.activation_dropout, float):
62
+ m.activation_dropout = p
63
+ n_reset += 1
64
+
65
+ LOG.info(f"Set {n_reset} dropout modules to p={p}")
66
+
67
+
68
+ def _inner_params(named_parameters, inner_names):
69
+ param_dict = dict(named_parameters)
70
+ return [(n, param_dict[n]) for n in inner_names]
71
+
72
+
73
+ def shift_targets(config):
74
+ return "t5" not in config.model.name.lower() and "blender" not in config.model.name.lower()
75
+
76
+
77
+ # https://stackoverflow.com/questions/32871539/integer-factorization-in-python
78
+ def factorization(n):
79
+ return [(i, n // i) for i in range(1, int(n**0.5) + 1) if n % i == 0]
80
+
81
+
82
+ def scr():
83
+ if os.path.exists("/scr-ssd"):
84
+ scr_dir = "/scr-ssd/" + getpass.getuser()
85
+ else:
86
+ scr_dir = "/scr/" + getpass.getuser()
87
+
88
+ if not os.path.exists(scr_dir):
89
+ os.makedirs(scr_dir)
90
+
91
+ return scr_dir
92
+
93
+
94
+ def uuid(digits=4):
95
+ if not hasattr(uuid, "uuid_value"):
96
+ uuid.uuid_value = struct.unpack('I', os.urandom(4))[0] % int(10**digits)
97
+
98
+ return uuid.uuid_value
99
+
100
+
101
+ def formatted_timestamp(time=None):
102
+ if time is None:
103
+ time = datetime.datetime.now()
104
+ return time.strftime("%d/%m/%Y-%H:%M:%S/%f")
105
+
106
+
107
+ def time_delta_seconds(start, finish=None):
108
+ assert type(start) == str
109
+
110
+ t1 = datetime.datetime.strptime(start, "%d/%m/%Y-%H:%M:%S/%f")
111
+ if finish is not None:
112
+ assert type(finish) == str
113
+ t2 = datetime.datetime.strptime(finish, "%d/%m/%Y-%H:%M:%S/%f")
114
+ else:
115
+ t2 = datetime.datetime.now()
116
+
117
+ return (t2 - t1).total_seconds()
118
+
119
+
120
+ def dict_to(d, device):
121
+ new_dict = {}
122
+ for k, v in d.items():
123
+ if isinstance(v, torch.Tensor):
124
+ new_dict[k] = v.to(device)
125
+ elif isinstance(v, dict):
126
+ new_dict[k] = dict_to(v, device)
127
+ else:
128
+ new_dict[k] = v
129
+
130
+ return new_dict
131
+
132
+
133
+ def safe_backward(loss, parameters, accumulate=1, allow_unused=False, backward=False):
134
+ if backward:
135
+ (loss / accumulate).backward()
136
+ else:
137
+ parameters = list(parameters) # Capture the generator output
138
+ grads = torch.autograd.grad(loss, parameters, allow_unused=allow_unused)
139
+ nan, inf = False, False
140
+ for g in grads:
141
+ if g is not None:
142
+ nan |= g.isnan().any().item()
143
+ inf |= g.isinf().any().item()
144
+
145
+ if not (nan or inf):
146
+ for p, g in zip(parameters, grads):
147
+ if g is None:
148
+ continue
149
+
150
+ if p.grad is None:
151
+ p.grad = g / accumulate
152
+ else:
153
+ p.grad += g / accumulate
154
+ else:
155
+ LOG.info(f"Skipping grad accumulation because inf: {inf} nan: {nan}")
156
+
157
+
158
+ def _logits(x):
159
+ return x if not hasattr(x, "logits") else x.logits
160
+
161
+
162
+ def _last_encoder_state(x):
163
+ if hasattr(x, "encoder_last_hidden_state"):
164
+ return x.encoder_last_hidden_state
165
+ else:
166
+ return x.hidden_states[-1]
167
+
168
+
169
+ def load_archive(path):
170
+ import torch
171
+
172
+ if not os.path.exists(path):
173
+ # We've not passed an explicit path, but a part of the filename
174
+ wd = '/iris/u/clin/code/efk/'
175
+ directories = ["outputs", "multirun"]
176
+ matches = []
177
+ for d in directories:
178
+ search = os.path.join(wd, d)
179
+ for run_dir in os.listdir(search):
180
+ if path in run_dir:
181
+ matches.append(os.path.join(search, run_dir))
182
+ assert len(matches) == 1, f">1 matches for search {path}; specify exact path"
183
+
184
+ full_run_dir = matches[0]
185
+ if "0" in os.listdir(full_run_dir):
186
+ full_run_dir = os.path.join(full_run_dir, "0")
187
+ models_dir = os.path.join(full_run_dir, "models")
188
+ models = os.listdir(models_dir)
189
+ non_bk = [m for m in models if not m.endswith(".bk")]
190
+ assert (
191
+ len(non_bk) == 1
192
+ ), f"Expected a single model in {models_dir}, got {len(non_bk)}"
193
+ path = os.path.join(models_dir, non_bk[0])
194
+
195
+ LOG.info(f"Loading checkpoint from {path}")
196
+ archive = torch.load(path, map_location="cpu")
197
+ LOG.info("Load complete.")
198
+
199
+ return archive, path
200
+
201
+
202
+ def flatten_dict(d):
203
+ to_process = list(d.items())
204
+ output = {}
205
+ while len(to_process):
206
+ k, v = to_process.pop()
207
+ if isinstance(v, typing.MutableMapping):
208
+ to_process.extend([(f"{k}.{k_}", v_) for (k_, v_) in v.items()])
209
+ else:
210
+ assert k not in output.keys(), "Somehow ended up with duplicate keys"
211
+ output[k] = v
212
+
213
+ return output
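For example, `flatten_dict` joins nested keys with dots (toy values; iteration order may differ):

    flatten_dict({"model": {"name": "gpt2", "pt": None}, "lr": 1e-5})
    # -> {"lr": 1e-05, "model.pt": None, "model.name": "gpt2"}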
214
+
215
+
216
+ def add_padding(tokenizer, model):
217
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
218
+ model.resize_token_embeddings(len(tokenizer))
219
+ model.transformer.wte.weight.data[-1] = model.transformer.wte.weight.data.mean(0)
220
+
221
+
222
+ def add_sep(tokenizer, model):
223
+ tokenizer.add_special_tokens({'sep_token': '[SEP]'})
224
+ # model.resize_token_embeddings(len(tokenizer))
225
+ # model.lm_head.weight.data[-1, :] = model.lm_head.weight.data.mean(0)
226
+
227
+
228
+ class EarlyStopper:
229
+ def __init__(self, patience: int, key: str, minimize: bool = False):
230
+ self.best_value = 1e9 if minimize else -1e9
231
+ self.best_iter = 0
232
+ self.current_iter = 0
233
+ self.key = key
234
+ self.patience = patience
235
+ self.minimize = minimize
236
+ self._stop = False
237
+
238
+ def update(self, idx, stats):
239
+ assert self.key in stats, f"'{self.key}' not in stats dict"
240
+ value = stats[self.key]
241
+ new_best = value < self.best_value if self.minimize else value > self.best_value
242
+ if new_best:
243
+ self.best_value = value
244
+ self.best_iter = idx
245
+
246
+ self.current_iter = idx
247
+ return new_best
248
+
249
+ def should_stop(self):
250
+ self._stop |= self.current_iter - self.best_iter >= self.patience
251
+ return self._stop
252
+
253
+
254
+ class RunningStatAverager:
255
+ def __init__(self, suffix="", exclude=["grad/"], compute_ppl: bool = True):
256
+ self.underlying = None
257
+ self.suffix = suffix
258
+ self.exclude = exclude
259
+ self.compute_ppl = compute_ppl
260
+
261
+ self.reset()
262
+
263
+ def add(self, d: dict):
264
+ for k, v in d.items():
265
+ if not any([k.startswith(prefix) for prefix in self.exclude]):
266
+ if len(self.suffix):
267
+ self.underlying[f"{k}_{self.suffix}"].append(v)
268
+ else:
269
+ self.underlying[k].append(v)
270
+
271
+ def average(self):
272
+ average = {}
273
+ for k, v in self.underlying.items():
274
+ if not k.startswith("nll/"):
275
+ average[k] = sum(v) / len(v)
276
+ else:
277
+ assert len(k.split("/")) == 2, f"Invalid key {k}"
278
+ name = k.split("/")[1]
279
+ token_counts = self.underlying[f"n_tokens/{name}"]
280
+ total_nll = sum([nll * c for nll, c in zip(v, token_counts)])
281
+ average[k] = total_nll / sum(token_counts)
282
+ if self.compute_ppl:
283
+ average[f"perplexity/{name}"] = math.e ** average[k]
284
+
285
+ return {k: v if not isinstance(v, torch.Tensor) else v.item() for k, v in average.items()}
286
+
287
+ def reset(self):
288
+ self.underlying = defaultdict(list)
289
+
290
+
291
+ class EditBatchSampler:
292
+ def __init__(
293
+ self,
294
+ n,
295
+ memorize_mode=False,
296
+ loc_disjoint=True,
297
+ seed=0,
298
+ hard_neg=False,
299
+ hard_neg_prob=1.0,
300
+ loc_distr_matrix=None,
301
+ loc_idx_matrix=None,
302
+ keep_probs=None,
303
+ mutex=None
304
+ ):
305
+ self.memorize_mode = memorize_mode
306
+ self.n = n
307
+ self.loc_disjoint = loc_disjoint
308
+ self.rng = np.random.default_rng(seed)
309
+ self.hard_neg = hard_neg
310
+ self.hard_neg_prob = hard_neg_prob
311
+ self.loc_probs = loc_distr_matrix
312
+ self.loc_idxs = loc_idx_matrix
313
+ self.keep_probs = np.array(keep_probs)[:self.n] if keep_probs is not None else None
314
+ self.mutex = mutex[:self.n] if mutex is not None else None
315
+ self._init()
316
+
317
+ def _init(self):
318
+ idxs = np.arange(self.n)
319
+ if self.keep_probs is not None:
320
+ sample = self.rng.binomial(1, self.keep_probs).astype(np.bool)
321
+ idxs = idxs[sample]
322
+
323
+ self.perm = self.rng.permutation(idxs)
324
+ self.edit_position = 0
325
+
326
+ def get_edit_idxs(self, batch_size):
327
+ if self.mutex is None:
328
+ idxs = set([int(idx) for idx in self.perm[self.edit_position: self.edit_position + batch_size]])
329
+ self.edit_position += batch_size
330
+ else:
331
+ mutexes = []
332
+ idxs = []
333
+
334
+ def notin(x, mutexes):
335
+ for m in mutexes:
336
+ if x in m or m in x:
337
+ return False
338
+ return True
339
+ while len(idxs) < batch_size:
340
+ new_idx = self.perm[self.edit_position]
341
+ if notin(self.mutex[new_idx], mutexes):
342
+ mutexes.append(self.mutex[new_idx])
343
+ idxs.append(int(new_idx))
344
+ self.edit_position += 1
345
+ if self.edit_position == self.perm.shape[0]:
346
+ return None
347
+
348
+ idxs = set(idxs)
349
+
350
+ return idxs
351
+
352
+ def sample(self, batch_size, return_hard_flag=False):
353
+ if self.memorize_mode:
354
+ return list(range(batch_size)), list(range(batch_size, batch_size * 2))
355
+
356
+ if self.edit_position + batch_size >= self.perm.shape[0]:
357
+ self._init() # Re-start if we end with a partially-sized batch
358
+
359
+ edit_idxs = self.get_edit_idxs(batch_size)
360
+ if edit_idxs is None:
361
+ self._init()
362
+ edit_idxs = self.get_edit_idxs(batch_size)
363
+ if edit_idxs is None:
364
+ raise RuntimeError(f"No valid batches of size {batch_size} exist!")
365
+
366
+ if self.hard_neg:
367
+ assert self.loc_probs is not None, "hard_neg is on, but don't have distance matrix!"
368
+
369
+ def get_loc_idxs():
370
+ if self.hard_neg and self.rng.uniform() < self.hard_neg_prob:
371
+ return [int(self.rng.choice(self.loc_idxs[idx], p=self.loc_probs[idx])) for idx in edit_idxs], True
372
+ else:
373
+ # Use deterministic implementation in case edit batches are large
374
+ non_edit_idxs = list(set(range(self.n)) - set(edit_idxs))
375
+ return [int(idx) for idx in self.rng.choice(non_edit_idxs, batch_size)], False
376
+
377
+ loc_idxs, hard = get_loc_idxs()
378
+ if self.loc_disjoint:
379
+ steps = 0
380
+ while len(edit_idxs.intersection(set(loc_idxs))) > 0:
381
+ loc_idxs, hard = get_loc_idxs()
382
+ steps += 1
383
+ if steps > 100:
384
+ raise RuntimeError("Can't find disjoint loc_idxs and edit_idxs!")
385
+
386
+ if return_hard_flag:
387
+ return list(edit_idxs), loc_idxs, hard
388
+ else:
389
+ return list(edit_idxs), loc_idxs
390
+
391
+
392
+ def parent_module(model, pname):
393
+ comps = pname.split('.')
394
+ parent = model
395
+ for comp in comps[:-1]:
396
+ if hasattr(parent, comp):
397
+ parent = getattr(parent, comp)
398
+ elif comp.isdigit():
399
+ parent = parent[int(comp)]
400
+ else:
401
+ raise RuntimeError(f"Couldn't find child module {comp}")
402
+ assert hasattr(parent, comps[-1])
403
+ return parent
404
+
405
+
406
+ def build_distr_matrix(edit_qs, config, loc_qs=None, slice_size=1000):
407
+ n = len(edit_qs)
408
+ device = "cuda" if torch.cuda.is_available() else "cpu"
409
+
410
+ num_neighbors = config.data.hard_neg_neighbors
411
+ num_exclude = config.data.hard_neg_exclude
412
+ temp = config.data.hard_neg_temp
413
+
414
+ from sentence_transformers import SentenceTransformer
415
+ from sentence_transformers.util import pytorch_cos_sim
416
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2', cache_folder=scr()).to(device)
417
+
418
+ ind_matrix = torch.zeros((n, num_neighbors - num_exclude), dtype=torch.long)
419
+ distr_matrix = torch.full((n, num_neighbors - num_exclude), float('nan'))
420
+ edit_encodings = torch.FloatTensor(embedding_model.encode(edit_qs, batch_size=256)).to(device)
421
+
422
+ # If loc_qs is None then build the similarity matrix between edit_qs and itself
423
+ loc_encodings = edit_encodings if loc_qs is None else embedding_model.encode(loc_qs, batch_size=256)
424
+ if isinstance(loc_encodings, np.ndarray):
425
+ loc_encodings = torch.FloatTensor(loc_encodings).to(device)
426
+
427
+ for idx in range(0, n, slice_size):
428
+ end_idx = idx + slice_size if idx + slice_size <= n else n
429
+ slice_encodings = edit_encodings[idx:end_idx]
430
+ sim_rows = pytorch_cos_sim(slice_encodings, loc_encodings)
431
+ indices = sim_rows.topk(num_neighbors, -1).indices[:, num_exclude:]
432
+ ind_matrix[idx:end_idx] = indices.cpu()
433
+ distr_matrix[idx:end_idx] = sim_rows.gather(-1, indices).mul(temp).exp().cpu()
434
+
435
+ assert not torch.isnan(distr_matrix).any()
436
+
437
+ LOG.info(f"Built hard negative distribution matrix of size {distr_matrix.shape}")
438
+ distr_matrix = distr_matrix.numpy()
439
+ distr_matrix = distr_matrix / distr_matrix.sum(-1, keepdims=True)
440
+ return distr_matrix, ind_matrix.numpy()
441
+