Martijn van Beers committed on
Commit
2e1a3f8
·
1 Parent(s): e8c51f1

try to make it work quick and dirty

Browse files
BERT_explainability/BERT.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from transformers import BertConfig
8
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
9
+ from BERT_explainability.modules.layers_ours import *
10
+ from transformers import (
11
+ BertPreTrainedModel,
12
+ PreTrainedModel,
13
+ )
14
+
15
# Registry of activation names (as used in ``config.hidden_act``) to the layer
# classes imported from ``layers_ours``; entries are instantiated by the caller.
ACT2FN = dict(relu=ReLU, tanh=Tanh, gelu=GELU)
20
+
21
+
22
def get_activation(activation_string):
    """Return the ``ACT2FN`` entry registered under ``activation_string``.

    Raises:
        KeyError: if the name is not a key of ``ACT2FN``; the message lists
            the registered names.
    """
    if activation_string not in ACT2FN:
        known_names = list(ACT2FN.keys())
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, known_names))
    return ACT2FN[activation_string]
27
+
28
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Combine per-layer matrices into a single rollout matrix.

    Each matrix gets the identity added (the residual connection), is
    normalised so that every row sums to one, and the results are chained with
    batched matrix products starting from ``start_layer``.

    Args:
        all_layer_matrices: non-empty sequence of tensors; the code reads
            ``shape[0]`` as the batch size and ``shape[1]`` as the number of
            tokens and multiplies them with ``bmm``.
        start_layer: index of the first matrix that takes part in the product.

    Returns:
        The product ``M[last] @ ... @ M[start_layer + 1] @ M[start_layer]`` of
        the normalised matrices.
    """
    reference = all_layer_matrices[0]
    batch_size, num_tokens = reference.shape[0], reference.shape[1]
    # Build the identity directly on the target device instead of creating it
    # on the CPU and copying it over on every call.
    eye = torch.eye(num_tokens, device=reference.device).expand(batch_size, num_tokens, num_tokens)

    normalized = []
    for matrix in all_layer_matrices:
        with_residual = matrix + eye
        normalized.append(with_residual / with_residual.sum(dim=-1, keepdim=True))

    joint_attention = normalized[start_layer]
    for i in range(start_layer + 1, len(normalized)):
        joint_attention = normalized[i].bmm(joint_attention)
    return joint_attention
40
+
41
class BertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then normalise and apply dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # The attribute keeps the TensorFlow-style name ``LayerNorm`` so that
        # checkpoints using that variable name can be loaded.
        self.LayerNorm = LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings), saved with the module state.
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

        # Explicit addition layers so relevance can be propagated through the sums.
        self.add1 = Add()
        self.add2 = Add()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed the inputs.

        Either ``input_ids`` or ``inputs_embeds`` must be given. Missing
        ``position_ids`` default to the first ``seq_length`` entries of the
        registered buffer; missing ``token_type_ids`` default to zeros.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_part = self.position_embeddings(position_ids)
        type_part = self.token_type_embeddings(token_type_ids)

        # (token_type + position) + word, each step through its own Add layer.
        summed = self.add1([type_part, position_part])
        summed = self.add2([summed, inputs_embeds])
        return self.dropout(self.LayerNorm(summed))

    def relprop(self, cam, **kwargs):
        """Propagate relevance back through dropout, LayerNorm and the last addition."""
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # NOTE(review): the result of add2.relprop is returned unsplit and add1
        # is never visited -- confirm this is the intended stopping point.
        return self.add2.relprop(cam, **kwargs)
95
+
96
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run the hidden states through every layer in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: forwarded unchanged to every layer.
            head_mask: optional per-layer masks; entry ``i`` goes to layer ``i``.
            encoder_hidden_states, encoder_attention_mask: accepted for
                signature compatibility; not forwarded to the layers.
            output_attentions: collect ``layer_outputs[1]`` of every layer.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: return a ``BaseModelOutput`` instead of a tuple.

        Returns:
            A tuple of the non-``None`` values among (last hidden state, all
            hidden states, all attentions), or a ``BaseModelOutput``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        use_checkpointing = getattr(self.config, "gradient_checkpointing", False)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if use_checkpointing:
                # Imported explicitly: ``import torch`` alone is not guaranteed
                # to make the ``torch.utils.checkpoint`` submodule available.
                from torch.utils.checkpoint import checkpoint

                def create_custom_forward(module):
                    # checkpoint() only passes positional tensors, so the flag
                    # is captured by the closure.
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def relprop(self, cam, **kwargs):
        """Propagate relevance through the layers from last to first.

        Assumes the forward pass was run with ``output_hidden_states=False``.
        """
        for layer_module in reversed(self.layer):
            cam = layer_module.relprop(cam, **kwargs)
        return cam
160
+
161
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state of its first token.

    The pooling only happens at the end of the network; ``relprop`` maps the
    relevance of the pooled vector back through the same three steps.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.activation = Tanh()
        self.pool = IndexSelect()

    def forward(self, hidden_states):
        """Select index 0 along dim 1, then apply the dense layer and the activation."""
        self._seq_size = hidden_states.shape[1]

        # Equivalent to hidden_states[:, 0], expressed through a layer that
        # also offers relprop.
        first_index = torch.tensor(0, device=hidden_states.device)
        first_token_tensor = self.pool(hidden_states, 1, first_index).squeeze(1)
        return self.activation(self.dense(first_token_tensor))

    def relprop(self, cam, **kwargs):
        """Propagate relevance through activation, dense layer and index selection."""
        cam = self.activation.relprop(cam, **kwargs)
        cam = self.dense.relprop(cam, **kwargs)
        # Restore the dimension that forward() squeezed away.
        return self.pool.relprop(cam.unsqueeze(1), **kwargs)
191
+
192
class BertAttention(nn.Module):
    """Self-attention followed by its output projection, with a residual branch."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
        self.clone = Clone()

    def prune_heads(self, heads):
        """Remove the given attention heads from the query/key/value/output projections.

        Does nothing when ``heads`` is empty.
        """
        if len(heads) == 0:
            return

        # These helpers are not imported at module level; import them here so
        # pruning does not fail with a NameError. Their home module differs
        # between transformers releases.
        try:
            from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
        except ImportError:
            from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer

        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        # NOTE(review): the pruned replacements may not be the relprop-capable
        # Linear class used elsewhere in this file -- confirm before combining
        # pruning with relprop.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Apply self-attention and the output block.

        The input is cloned: one copy feeds the attention, the other is the
        residual input of the output block.

        Returns:
            ``(attention_output,)`` followed by any extra outputs of the
            self-attention module.
        """
        h1, h2 = self.clone(hidden_states, 2)
        self_outputs = self.self(
            h1,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], h2)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate relevance back through the output block, attention and clone.

        Assumes the forward pass did not output the attentions, i.e.
        ``outputs = (attention_output,)`` and ``self_outputs = (context_layer,)``.
        """
        (cam1, cam2) = self.output.relprop(cam, **kwargs)
        cam1 = self.self.relprop(cam1, **kwargs)

        # Merge the attention branch and the residual branch.
        return self.clone.relprop((cam1, cam2), **kwargs)
248
+
249
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention built from relprop-capable layers.

    Besides computing the attention output, the forward pass stores the
    attention probabilities and (when gradients are tracked) registers a hook
    that stores their gradient; ``relprop`` stores the relevance that reaches
    the attention probabilities.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.dropout = Dropout(config.attention_probs_dropout_prob)

        self.matmul1 = MatMul()
        self.matmul2 = MatMul()
        self.softmax = Softmax(dim=-1)
        self.add = Add()
        self.mul = Mul()
        # Masks seen by the last forward pass; relprop uses them to decide
        # which layers to propagate through.
        self.head_mask = None
        self.attention_mask = None
        self.clone = Clone()

        # Values recorded for explanation methods.
        self.attn_cam = None
        self.attn = None
        self.attn_gradients = None

    def get_attn(self):
        """Return the attention probabilities saved by the last forward pass."""
        return self.attn

    def save_attn(self, attn):
        """Store the attention probabilities."""
        self.attn = attn

    def save_attn_cam(self, cam):
        """Store the relevance that reached the attention probabilities."""
        self.attn_cam = cam

    def get_attn_cam(self):
        """Return the relevance saved by the last ``relprop`` call."""
        return self.attn_cam

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention probabilities (used as a tensor hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradient saved by the backward hook."""
        return self.attn_gradients

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and move heads to dim 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_for_scores_relprop(self, x):
        """Inverse of ``transpose_for_scores``: swap dims 1 and 2, then merge the last two dims."""
        return x.permute(0, 2, 1, 3).flatten(2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Compute the attention output.

        Returns:
            ``(context_layer, attention_probs)`` when ``output_attentions`` is
            true, otherwise ``(context_layer,)``.
        """
        self.head_mask = head_mask
        self.attention_mask = attention_mask

        h1, h2, h3 = self.clone(hidden_states, 3)
        mixed_query_layer = self.query(h1)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        # NOTE(review): in that case h2/h3 are unused while relprop still
        # merges three relevance tensors through self.clone -- confirm that
        # cross-attention is not expected to support relprop.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(h2)
            mixed_value_layer = self.value(h3)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel.forward()).
            attention_scores = self.add([attention_scores, attention_mask])

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)

        self.save_attn(attention_probs)
        # register_hook raises on tensors that do not require grad (e.g. under
        # torch.no_grad()), so only hook when a gradient can actually flow.
        if attention_probs.requires_grad:
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        # NOTE(review): this multiplies directly while relprop goes through
        # self.mul.relprop -- confirm Mul.relprop works without a forward call.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = self.matmul2([attention_probs, value_layer])

        # Merge the heads back into one hidden dimension.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate relevance from the context layer back to the hidden states.

        Assumes the forward pass was run with ``output_attentions == False``.
        The relevance returned by each matmul is halved for both operands.
        """
        cam = self.transpose_for_scores(cam)

        # [attention_probs, value_layer]
        (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam2 /= 2
        if self.head_mask is not None:
            # [attention_probs, head_mask]
            (cam1, _) = self.mul.relprop(cam1, **kwargs)

        self.save_attn_cam(cam1)

        cam1 = self.dropout.relprop(cam1, **kwargs)

        cam1 = self.softmax.relprop(cam1, **kwargs)

        if self.attention_mask is not None:
            # [attention_scores, attention_mask]
            (cam1, _) = self.add.relprop(cam1, **kwargs)

        # [query_layer, key_layer.transpose(-1, -2)]
        (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
        cam1_1 /= 2
        cam1_2 /= 2

        # query
        cam1_1 = self.transpose_for_scores_relprop(cam1_1)
        cam1_1 = self.query.relprop(cam1_1, **kwargs)

        # key (undo the transpose applied before matmul1)
        cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
        cam1_2 = self.key.relprop(cam1_2, **kwargs)

        # value
        cam2 = self.transpose_for_scores_relprop(cam2)
        cam2 = self.value.relprop(cam2, **kwargs)

        cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)

        return cam
410
+
411
+
412
class BertSelfOutput(nn.Module):
    """Dense projection and dropout, added to the residual input, then LayerNorm."""

    def __init__(self, config):
        super().__init__()
        size = config.hidden_size
        self.dense = Linear(size, size)
        self.LayerNorm = LayerNorm(size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(self.add([projected, input_tensor]))

    def relprop(self, cam, **kwargs):
        """Propagate relevance backwards.

        Returns:
            A pair ``(cam1, cam2)``: relevance for ``hidden_states`` (after
            passing back through dropout and dense) and for ``input_tensor``.
        """
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # The addition splits relevance between [hidden_states, input_tensor].
        branch_cam, residual_cam = self.add.relprop(cam, **kwargs)
        branch_cam = self.dropout.relprop(branch_cam, **kwargs)
        branch_cam = self.dense.relprop(branch_cam, **kwargs)
        return (branch_cam, residual_cam)
435
+
436
+
437
class BertIntermediate(nn.Module):
    """Dense expansion to ``config.intermediate_size`` followed by an activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN and instantiated; anything else is
        # used as given.
        self.intermediate_act_fn = ACT2FN[activation]() if isinstance(activation, str) else activation

    def forward(self, hidden_states):
        """Apply the dense layer, then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))

    def relprop(self, cam, **kwargs):
        """Propagate relevance back through the activation and the dense layer."""
        # FIXME only ReLU
        cam = self.intermediate_act_fn.relprop(cam, **kwargs)
        return self.dense.relprop(cam, **kwargs)
457
+
458
+
459
class BertOutput(nn.Module):
    """Project from the intermediate size back to the hidden size, add the residual, normalise."""

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(self.add([projected, input_tensor]))

    def relprop(self, cam, **kwargs):
        """Propagate relevance backwards.

        Returns:
            A pair ``(cam1, cam2)``: relevance for ``hidden_states`` (after
            passing back through dropout and dense) and for ``input_tensor``.
        """
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # The addition splits relevance between [hidden_states, input_tensor].
        branch_cam, residual_cam = self.add.relprop(cam, **kwargs)
        branch_cam = self.dropout.relprop(branch_cam, **kwargs)
        branch_cam = self.dense.relprop(branch_cam, **kwargs)
        return (branch_cam, residual_cam)
488
+
489
+
490
class BertLayer(nn.Module):
    """One encoder layer: attention block followed by the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run attention, then the feed-forward block with a residual branch.

        Returns:
            ``(layer_output,)`` followed by any extra outputs of the attention
            block (the attention weights when they are requested).
        """
        attention_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output, extras = attention_results[0], attention_results[1:]

        # One copy feeds the intermediate layer, the other is the residual.
        ao1, ao2 = self.clone(attention_output, 2)
        layer_output = self.output(self.intermediate(ao1), ao2)

        return (layer_output,) + extras

    def relprop(self, cam, **kwargs):
        """Propagate relevance back through output, intermediate, clone and attention."""
        feed_forward_cam, residual_cam = self.output.relprop(cam, **kwargs)
        feed_forward_cam = self.intermediate.relprop(feed_forward_cam, **kwargs)
        merged = self.clone.relprop((feed_forward_cam, residual_cam), **kwargs)
        return self.attention.relprop(merged, **kwargs)
531
+
532
+
533
class BertModel(BertPreTrainedModel):
    """Embeddings, encoder stack and pooler, with relevance propagation support."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run embeddings, encoder and pooler.

        Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
        ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        fall back to the values in ``self.config`` when they are ``None``.

        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

        Returns ``(sequence_output, pooled_output)`` plus the encoder's extra
        outputs, or a ``BaseModelOutputWithPooling`` when ``return_dict`` is set.
        """
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    def relprop(self, cam, **kwargs):
        """Propagate relevance back through the pooler and then the encoder.

        The embeddings are not visited.
        """
        cam = self.pooler.relprop(cam, **kwargs)
        return self.encoder.relprop(cam, **kwargs)
652
+
653
+
654
if __name__ == '__main__':
    # Manual smoke test for BertSelfAttention on random input.
    class Config:
        """Minimal stand-in exposing the attributes BertSelfAttention reads."""

        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob

    cfg = Config(1024, 4, 0.1)
    model = BertSelfAttention(cfg)
    x = torch.rand(2, 20, 1024)
    x.requires_grad_()

    model.eval()

    y = model.forward(x)

    # NOTE(review): relprop is declared as relprop(self, cam, **kwargs) but is
    # called here with a second positional argument -- confirm the intended
    # keyword arguments before relying on this script.
    relevance_seed = torch.rand(2, 20, 1024)
    relprop = model.relprop(relevance_seed, (torch.rand(2, 20, 1024),))

    print(relprop[1][0].shape)
BERT_explainability/BERT_cls_lrp.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertPreTrainedModel
2
+ from transformers.utils import logging
3
+ from BERT_explainability.modules.layers_lrp import *
4
+ from BERT_explainability.modules.BERT.BERT_orig_lrp import BertModel
5
+ from torch.nn import CrossEntropyLoss, MSELoss
6
+ import torch.nn as nn
7
+ from typing import List, Any
8
+ import torch
9
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
10
+
11
+
12
class BertForSequenceClassification(BertPreTrainedModel):
    """BERT with dropout and a linear classification head on the pooled output."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Classify the input sequence.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns a tuple ``(loss?, logits, *extra encoder outputs)`` or, when
        ``return_dict`` is set, a ``SequenceClassifierOutput``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output of BertModel.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # SequenceClassifierOutput is not imported at module level; import it
        # here so the return_dict path does not fail with a NameError.
        from transformers.modeling_outputs import SequenceClassifierOutput

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def relprop(self, cam=None, **kwargs):
        """Propagate relevance back through the classifier, dropout and BERT model."""
        cam = self.classifier.relprop(cam, **kwargs)
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.bert.relprop(cam, **kwargs)
        return cam
88
+
89
+
90
+ # this is the actual classifier we will be using
91
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification.

    Builds ``[CLS] query [SEP] document`` inputs, pads them into a batch and
    returns the class scores produced by the wrapped model.
    """

    def __init__(self,
                 bert_dir: str,
                 pad_token_id: int,
                 cls_token_id: int,
                 sep_token_id: int,
                 num_labels: int,
                 max_length: int = 512,
                 use_half_precision=True):
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
        if use_half_precision:
            # ``.half()`` is plain PyTorch; the former unused ``import apex``
            # only made this path fail when apex was not installed.
            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(self,
                query: List[torch.Tensor],
                docids: List[Any],
                document_batch: List[torch.Tensor]):
        """Score each (query, document) pair.

        Args:
            query: one token-id tensor per example.
            docids: accepted but not used by this method.
            document_batch: one token-id tensor per example; truncated so that
                query + document + 2 special tokens fit in ``max_length``.

        Returns:
            The single output of the wrapped model (the class scores).
        """
        assert len(query) == len(document_batch)
        # NOTE(review): debugging output left in the forward pass.
        print(query)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
        # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
        sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[:(self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            # Positions restart at 0 for the [SEP] + document segment.
            position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
        bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
                                            device=target_device)
        positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
        (classes,) = self.bert(bert_input.data,
                               attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
                               position_ids=positions.data)
        assert torch.all(classes == classes)  # for nans

        # NOTE(review): debugging output; relprop is invoked here without a
        # relevance tensor (cam=None) on every forward call -- confirm intent.
        print(input_tensors[0])
        print(self.relprop()[0])

        return classes

    def relprop(self, cam=None, **kwargs):
        """Delegate relevance propagation to the wrapped model."""
        return self.bert.relprop(cam, **kwargs)
147
+
148
+
149
if __name__ == '__main__':
    # Manual check: classify one movie-review sentence and run relprop.
    from transformers import BertTokenizer
    import os

    class Config:
        """Plain attribute holder; not used by the rest of this script."""

        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
                     hidden_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    sentence = "In this movie the acting is great. The movie is perfect! [sep]"
    x = tokenizer.encode_plus(sentence,
                              add_special_tokens=True,
                              max_length=512,
                              return_token_type_ids=False,
                              return_attention_mask=True,
                              pad_to_max_length=True,
                              return_tensors='pt',
                              truncation=True)

    print(x['input_ids'])

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    checkpoint_dir = './BERT_explainability/output_bert/movies/classifier/'
    model_save_file = os.path.join(checkpoint_dir, 'classifier.pt')
    model.load_state_dict(torch.load(model_save_file))

    # (An earlier variant fed a hard-coded tensor of token ids instead of the
    # tokenizer output.)

    model.eval()

    y = model(x['input_ids'], x['attention_mask'])
    print(y)

    cam, _ = model.relprop()

    # Sum the relevance over the last dimension.
    cam = cam.sum(-1)
BERT_explainability/BERT_orig_lrp.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from transformers import BertConfig
8
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
9
+ from BERT_explainability.modules.layers_lrp import *
10
+ from transformers import (
11
+ BertPreTrainedModel,
12
+ PreTrainedModel,
13
+ )
14
+
15
+ ACT2FN = {
16
+ "relu": ReLU,
17
+ "tanh": Tanh,
18
+ "gelu": GELU,
19
+ }
20
+
21
+
22
def get_activation(activation_string):
    """Return the activation layer class registered under ``activation_string``.

    The name is looked up in the module-level ``ACT2FN`` table; the class
    itself (not an instance) is returned.

    Raises:
        KeyError: if ``activation_string`` has no entry in ``ACT2FN``.
    """
    if activation_string not in ACT2FN:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
    return ACT2FN[activation_string]
27
+
28
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Combine per-layer attention matrices into one rollout matrix.

    Each matrix gets an identity added (to account for the residual
    connection), is renormalized so every row sums to one, and the results
    are chained with batched matrix products from ``start_layer`` upwards.

    Args:
        all_layer_matrices: sequence of 3-D tensors whose dims 0 and 1 are
            read as batch size and token count (``bmm`` is applied to them).
        start_layer: index of the first layer included in the product.

    Returns:
        The accumulated product, same shape as a single layer matrix.
    """
    first = all_layer_matrices[0]
    batch_size, num_tokens = first.shape[0], first.shape[1]
    # Build the identity directly on the target device instead of creating it
    # on the default device and moving it afterwards.
    eye = torch.eye(num_tokens, device=first.device).expand(batch_size, num_tokens, num_tokens)

    # adding residual consideration
    with_residual = [matrix + eye for matrix in all_layer_matrices]
    normalized = [matrix / matrix.sum(dim=-1, keepdim=True) for matrix in with_residual]

    joint_attention = normalized[start_layer]
    for i in range(start_layer + 1, len(normalized)):
        joint_attention = normalized[i].bmm(joint_attention)
    return joint_attention
40
+
41
class BertEmbeddings(nn.Module):
    """Word + position + token-type embeddings followed by LayerNorm and dropout.

    The two additions are routed through ``Add`` layers and normalization /
    dropout through this project's ``LayerNorm`` / ``Dropout`` layers, each of
    which is expected to expose ``relprop``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)

        # The attribute keeps its TensorFlow-style name (not snake_case) so
        # that TensorFlow checkpoint variable names still map onto it.
        self.LayerNorm = LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)

        # Shape (1, max_position_embeddings); registered as a buffer so it is
        # saved with the module state.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

        self.add1 = Add()
        self.add2 = Add()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed the inputs.

        Either ``input_ids`` or ``inputs_embeds`` must be given. Missing
        ``position_ids`` default to the first ``seq_length`` buffered
        positions and missing ``token_type_ids`` default to zeros.
        """
        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Same sum as inputs_embeds + position_embeddings + token_type_embeddings,
        # expressed through two Add layers.
        type_plus_position = self.add1([token_type_embeddings, position_embeddings])
        summed = self.add2([type_plus_position, inputs_embeds])
        return self.dropout(self.LayerNorm(summed))

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through dropout, LayerNorm and the second Add.

        Only ``add2`` is traversed (``add1`` is not), so the value returned is
        whatever ``add2.relprop`` yields.
        """
        cam = self.LayerNorm.relprop(self.dropout.relprop(cam, **kwargs), **kwargs)
        return self.add2.relprop(cam, **kwargs)
95
+
96
class BertEncoder(nn.Module):
    """Stack of ``BertLayer`` modules with a matching ``relprop`` pass."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        """Run ``hidden_states`` through every layer in order.

        ``encoder_hidden_states`` and ``encoder_attention_mask`` are accepted
        for signature compatibility but are not forwarded to the layers.

        Returns:
            With ``return_dict`` false, a tuple of the final hidden states
            followed by the collected hidden states and attentions (each
            only when requested). Otherwise a ``BaseModelOutput``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        use_checkpointing = getattr(self.config, "gradient_checkpointing", False)
        if use_checkpointing:
            # Import the submodule explicitly: `import torch` alone is not
            # guaranteed to load `torch.utils.checkpoint`. The from-import
            # form avoids rebinding `torch` as a local name.
            from torch.utils.checkpoint import checkpoint

        def create_custom_forward(module):
            # checkpoint() passes tensors positionally, so the flag is
            # appended as the last positional argument here.
            def custom_forward(*inputs):
                return module(*inputs, output_attentions)

            return custom_forward

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if use_checkpointing:
                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` through the layers from last to first."""
        # assuming output_hidden_states is False
        for layer_module in reversed(self.layer):
            cam = layer_module.relprop(cam, **kwargs)
        return cam
160
+
161
+ # not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
162
class BertPooler(nn.Module):
    """Pool a sequence by taking its first token through a dense + Tanh head."""

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.activation = Tanh()
        self.pool = IndexSelect()

    def forward(self, hidden_states):
        """Return the pooled representation of ``hidden_states``.

        The size of dim 1 is remembered in ``self._seq_size``. Index 0 along
        dim 1 is selected through ``self.pool`` (the layer-based equivalent
        of ``hidden_states[:, 0]``).
        """
        self._seq_size = hidden_states.shape[1]

        first_index = torch.tensor(0, device=hidden_states.device)
        first_token_tensor = self.pool(hidden_states, 1, first_index).squeeze(1)
        return self.activation(self.dense(first_token_tensor))

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through Tanh, the dense layer and the selection."""
        cam = self.activation.relprop(cam, **kwargs)
        cam = self.dense.relprop(cam, **kwargs)
        # Restore the dim that forward() squeezed away before undoing the selection.
        return self.pool.relprop(cam.unsqueeze(1), **kwargs)
191
+
192
class BertAttention(nn.Module):
    """Self-attention plus its output projection, with a residual branch.

    The input is duplicated by a ``Clone`` layer: one copy feeds the
    self-attention, the other is passed to the output block as the residual.
    """

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
        self.clone = Clone()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projection layers.

        NOTE(review): ``find_pruneable_heads_and_indices`` and
        ``prune_linear_layer`` are not among the names this file imports
        explicitly - confirm they are in scope before relying on this method.
        """
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Return ``(attention_output, *extras)`` where extras are whatever the
        self-attention returned after its first element."""
        attn_input, residual = self.clone(hidden_states, 2)
        self_outputs = self.self(
            attn_input,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], residual)
        # add attentions if we output them
        return (attention_output,) + self_outputs[1:]

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through the output block, the self-attention
        branch, and finally the clone that merges both branches."""
        # assuming the attentions were not output, i.e. outputs == (attention_output,)
        attn_cam, residual_cam = self.output.relprop(cam, **kwargs)
        attn_cam = self.self.relprop(attn_cam, **kwargs)
        return self.clone.relprop((attn_cam, residual_cam), **kwargs)
248
+
249
class BertSelfAttention(nn.Module):
    """Multi-head attention assembled from layers that expose ``relprop``.

    Besides computing attention, the forward pass stores the attention
    probabilities (``get_attn``) and, when gradients are tracked, registers a
    hook that stores their gradient (``get_attn_gradients``). ``relprop``
    stores the relevance reaching the attention probabilities
    (``get_attn_cam``).
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.dropout = Dropout(config.attention_probs_dropout_prob)

        self.matmul1 = MatMul()
        self.matmul2 = MatMul()
        self.softmax = Softmax(dim=-1)
        self.add = Add()
        self.mul = Mul()
        # Remembered from the latest forward() so relprop() knows which
        # optional steps were applied.
        self.head_mask = None
        self.attention_mask = None
        self.clone = Clone()

        self.attn_cam = None
        self.attn = None
        self.attn_gradients = None

    def get_attn(self):
        return self.attn

    def save_attn(self, attn):
        self.attn = attn

    def save_attn_cam(self, cam):
        self.attn_cam = cam

    def get_attn_cam(self):
        return self.attn_cam

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_for_scores_relprop(self, x):
        """Inverse of ``transpose_for_scores``: swap dims 1 and 2, then merge the last two."""
        return x.permute(0, 2, 1, 3).flatten(2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Compute attention over ``hidden_states``.

        Returns:
            ``(context_layer, attention_probs)`` when ``output_attentions`` is
            true, otherwise ``(context_layer,)``.
        """
        self.head_mask = head_mask
        self.attention_mask = attention_mask

        h1, h2, h3 = self.clone(hidden_states, 3)
        mixed_query_layer = self.query(h1)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(h2)
            mixed_value_layer = self.value(h3)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel.forward()).
            attention_scores = self.add([attention_scores, attention_mask])

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)

        self.save_attn(attention_probs)
        if attention_probs.requires_grad:
            attention_probs.register_hook(self.save_attn_gradients)
        else:
            # register_hook() raises on tensors that do not require grad (for
            # example under torch.no_grad()); skip it and drop any gradient
            # left over from an earlier pass so it cannot be mistaken for
            # this one.
            self.attn_gradients = None

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            # NOTE(review): this multiplies directly instead of going through
            # self.mul, while relprop() calls self.mul.relprop - confirm that
            # Mul.relprop works without a preceding forward through self.mul.
            attention_probs = attention_probs * head_mask

        context_layer = self.matmul2([attention_probs, value_layer])

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back to the layer input.

        The relevance returned by each matmul is halved for both of its
        operands; the three branches (query, key, value) are merged by the
        clone layer at the end.
        """
        # Assume output_attentions == False
        cam = self.transpose_for_scores(cam)

        # [attention_probs, value_layer]
        (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam2 /= 2
        if self.head_mask is not None:
            # [attention_probs, head_mask]
            (cam1, _) = self.mul.relprop(cam1, **kwargs)

        self.save_attn_cam(cam1)

        cam1 = self.dropout.relprop(cam1, **kwargs)

        cam1 = self.softmax.relprop(cam1, **kwargs)

        if self.attention_mask is not None:
            # [attention_scores, attention_mask]
            (cam1, _) = self.add.relprop(cam1, **kwargs)

        # [query_layer, key_layer.transpose(-1, -2)]
        (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
        cam1_1 /= 2
        cam1_2 /= 2

        # query
        cam1_1 = self.transpose_for_scores_relprop(cam1_1)
        cam1_1 = self.query.relprop(cam1_1, **kwargs)

        # key (undo the transpose applied to the key operand in forward())
        cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
        cam1_2 = self.key.relprop(cam1_2, **kwargs)

        # value
        cam2 = self.transpose_for_scores_relprop(cam2)
        cam2 = self.value.relprop(cam2, **kwargs)

        cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)

        return cam
410
+
411
+
412
class BertSelfOutput(nn.Module):
    """Dense projection, dropout, residual add and LayerNorm after self-attention."""

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, add ``input_tensor`` as residual, normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(self.add([projected, input_tensor]))

    def relprop(self, cam, **kwargs):
        """Return ``(cam for hidden_states, cam for input_tensor)``."""
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # The Add layer splits relevance in the order [hidden_states, input_tensor].
        projected_cam, residual_cam = self.add.relprop(cam, **kwargs)
        projected_cam = self.dropout.relprop(projected_cam, **kwargs)
        projected_cam = self.dense.relprop(projected_cam, **kwargs)
        return (projected_cam, residual_cam)
435
+
436
+
437
class BertIntermediate(nn.Module):
    """Dense expansion to ``intermediate_size`` followed by an activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        # A string names an entry in ACT2FN (instantiated here); anything else
        # is used as the activation object directly.
        self.intermediate_act_fn = (
            ACT2FN[config.hidden_act]() if isinstance(config.hidden_act, str) else config.hidden_act
        )

    def forward(self, hidden_states):
        """Apply the dense layer, then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through the activation and the dense layer."""
        # FIXME (carried over from the original author): only ReLU
        cam = self.intermediate_act_fn.relprop(cam, **kwargs)
        return self.dense.relprop(cam, **kwargs)
457
+
458
+
459
class BertOutput(nn.Module):
    """Dense projection back to ``hidden_size``, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states``, add ``input_tensor`` as residual, normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(self.add([projected, input_tensor]))

    def relprop(self, cam, **kwargs):
        """Return ``(cam for hidden_states, cam for input_tensor)``."""
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # The Add layer splits relevance in the order [hidden_states, input_tensor].
        projected_cam, residual_cam = self.add.relprop(cam, **kwargs)
        projected_cam = self.dropout.relprop(projected_cam, **kwargs)
        projected_cam = self.dense.relprop(projected_cam, **kwargs)
        return (projected_cam, residual_cam)
488
+
489
+
490
class BertLayer(nn.Module):
    """One transformer block: attention, then the feed-forward pair.

    The attention output is duplicated by a ``Clone`` layer so that one copy
    feeds the intermediate layer and the other serves as the residual input
    of the output layer.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Return ``(layer_output, *extras)``; extras are whatever the attention
        block returned after its first element."""
        attention_result = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output, extras = attention_result[0], attention_result[1:]

        ff_input, residual = self.clone(attention_output, 2)
        layer_output = self.output(self.intermediate(ff_input), residual)
        return (layer_output,) + extras

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through output, intermediate, clone and attention."""
        ff_cam, residual_cam = self.output.relprop(cam, **kwargs)
        ff_cam = self.intermediate.relprop(ff_cam, **kwargs)
        merged = self.clone.relprop((ff_cam, residual_cam), **kwargs)
        return self.attention.relprop(merged, **kwargs)
531
+
532
+
533
class BertModel(BertPreTrainedModel):
    """BERT backbone (embeddings, encoder, pooler) with a ``relprop`` pass.

    The sub-modules are the relevance-propagation-aware classes defined in
    this file; configuration handling and mask helpers come from
    ``BertPreTrainedModel``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with ``value``."""
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Run embeddings, encoder and pooler.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given; a
        ``ValueError`` is raised when both or neither are supplied. Options
        left as ``None`` fall back to the corresponding ``self.config`` value.

        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

        Returns a tuple ``(sequence_output, pooled_output, *encoder extras)``
        when ``return_dict`` is false, otherwise a
        ``BaseModelOutputWithPooling``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Default to attending to every position and to segment 0 everywhere.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through the pooler and then the encoder.

        The embeddings' own ``relprop`` is not called here, so the result is
        the relevance at the encoder input.
        """
        cam = self.pooler.relprop(cam, **kwargs)
        # print("111111111111",cam.sum())
        cam = self.encoder.relprop(cam, **kwargs)
        # print("222222222222222", cam.sum())
        # print("conservation: ", cam.sum())
        return cam
652
+
653
+
654
if __name__ == '__main__':
    # Ad-hoc smoke test: build a single BertSelfAttention from a minimal
    # stand-in config and run a forward and a relprop pass on random data.
    class Config:
        # Carries only the three attributes BertSelfAttention.__init__ reads.
        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob

    model = BertSelfAttention(Config(1024, 4, 0.1))
    x = torch.rand(2, 20, 1024)
    x.requires_grad_()

    model.eval()

    y = model.forward(x)

    # NOTE(review): BertSelfAttention.relprop is declared as
    # relprop(self, cam, **kwargs); the second positional argument passed here
    # does not fit that signature - this call looks stale, confirm before use.
    relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))

    print(relprop[1][0].shape)
BERT_explainability/BERTalt.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ from BERT_explainability.modules.layers_ours import *
8
+
9
+ import transformers
10
+
11
+ from transformers import BertConfig
12
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
13
+ from transformers import (
14
+ BertPreTrainedModel,
15
+ PreTrainedModel,
16
+ )
17
+
18
+
19
+ ACT2FN = {
20
+ "relu": ReLU,
21
+ "tanh": Tanh,
22
+ "gelu": GELU,
23
+ }
24
+
25
+
26
def get_activation(activation_string):
    """Return the activation layer class registered under ``activation_string``.

    The name is looked up in the module-level ``ACT2FN`` table; the class
    itself (not an instance) is returned.

    Raises:
        KeyError: if ``activation_string`` has no entry in ``ACT2FN``.
    """
    if activation_string not in ACT2FN:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
    return ACT2FN[activation_string]
31
+
32
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Combine per-layer attention matrices into one rollout matrix.

    Each matrix gets an identity added (to account for the residual
    connection), is renormalized so every row sums to one, and the results
    are chained with batched matrix products from ``start_layer`` upwards.

    Args:
        all_layer_matrices: sequence of 3-D tensors whose dims 0 and 1 are
            read as batch size and token count (``bmm`` is applied to them).
        start_layer: index of the first layer included in the product.

    Returns:
        The accumulated product, same shape as a single layer matrix.
    """
    first = all_layer_matrices[0]
    batch_size, num_tokens = first.shape[0], first.shape[1]
    # Build the identity directly on the target device instead of creating it
    # on the default device and moving it afterwards.
    eye = torch.eye(num_tokens, device=first.device).expand(batch_size, num_tokens, num_tokens)

    # adding residual consideration
    with_residual = [matrix + eye for matrix in all_layer_matrices]
    normalized = [matrix / matrix.sum(dim=-1, keepdim=True) for matrix in with_residual]

    joint_attention = normalized[start_layer]
    for i in range(start_layer + 1, len(normalized)):
        joint_attention = normalized[i].bmm(joint_attention)
    return joint_attention
44
+
45
class RPBertEmbeddings(BertEmbeddings):
    """``BertEmbeddings`` variant whose additions go through ``Add`` layers.

    NOTE(review): ``BertEmbeddings`` is not imported by name in this file's
    explicit imports - confirm it is in scope. ``relprop`` below also relies
    on the inherited ``dropout`` / ``LayerNorm`` modules exposing
    ``relprop``; verify that against the base class.
    """

    def __init__(self, config):
        # Pass the config on: forward() uses the embedding tables, LayerNorm,
        # dropout and position_ids, none of which are created in this class.
        super().__init__(config)

        self.add1 = Add()
        self.add2 = Add()

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed the inputs.

        Either ``input_ids`` or ``inputs_embeds`` must be given. Missing
        ``position_ids`` default to the first ``seq_length`` buffered
        positions and missing ``token_type_ids`` default to zeros.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # Same sum as inputs_embeds + position_embeddings + token_type_embeddings,
        # expressed through two Add layers.
        embeddings = self.add1([token_type_embeddings, position_embeddings])
        embeddings = self.add2([embeddings, inputs_embeds])
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through dropout, LayerNorm and the second Add.

        Only ``add2`` is traversed (``add1`` is not), so the value returned is
        whatever ``add2.relprop`` yields.
        """
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.LayerNorm.relprop(cam, **kwargs)
        cam = self.add2.relprop(cam, **kwargs)

        return cam
86
+
87
class RPBertEncoder(transformers.modeling_bert.BertEncoder):
    """Library ``BertEncoder`` with its layers replaced and a ``relprop`` pass added."""

    def __init__(self, config):
        # The parent constructor takes the config; calling it without one
        # raises TypeError before this class can be instantiated.
        super().__init__(config)
        self.config = config
        # Replace the parent's layer stack with this project's BertLayer.
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` through the layers from last to first."""
        # assuming output_hidden_states is False
        for layer_module in reversed(self.layer):
            cam = layer_module.relprop(cam, **kwargs)
        return cam
98
+
99
+
100
+ # not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
101
class RPBertPooler(transformers.modeling_bert.BertPooler):
    """Library ``BertPooler`` that selects the first token through ``IndexSelect``.

    NOTE(review): ``relprop`` calls ``relprop`` on the inherited ``dense`` and
    ``activation`` modules - confirm the parent's modules provide it.
    """

    def __init__(self, config):
        # The parent constructor takes the config; forward() uses the
        # dense/activation modules, which are not created in this class.
        super().__init__(config)
        self.pool = IndexSelect()

    def forward(self, hidden_states):
        """Return the pooled representation of ``hidden_states``.

        The size of dim 1 is remembered in ``self._seq_size``. Index 0 along
        dim 1 is selected through ``self.pool`` (the layer-based equivalent
        of ``hidden_states[:, 0]``).
        """
        self._seq_size = hidden_states.shape[1]

        first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
        first_token_tensor = first_token_tensor.squeeze(1)
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

    def relprop(self, cam, **kwargs):
        """Propagate ``cam`` back through the activation, dense layer and selection."""
        cam = self.activation.relprop(cam, **kwargs)
        cam = self.dense.relprop(cam, **kwargs)
        # Restore the dim that forward() squeezed away before undoing the selection.
        cam = cam.unsqueeze(1)
        cam = self.pool.relprop(cam, **kwargs)

        return cam
128
+
129
class BertAttention(transformers.modeling_bert.BertAttention):
    """Attention block (self-attention + output projection) with relprop.

    NOTE(review): ``self.self`` and ``self.output`` are created by the parent
    constructor; ``relprop`` below calls ``relprop`` on both, so they must be
    relprop-capable modules - confirm which classes the parent instantiates.
    """

    def __init__(self, config):
        # The parent constructor needs the config to build its sub-modules.
        super().__init__(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Run self-attention followed by the residual output projection.

        The input is duplicated through ``Clone`` so that the attention
        branch (``h1``) and the residual branch (``h2``) can later receive
        separate relevance in ``relprop``.

        Returns:
            Tuple ``(attention_output, *extra)`` where ``extra`` is whatever
            the self-attention returned after its first element.
        """
        h1, h2 = self.clone(hidden_states, 2)
        self_outputs = self.self(
            h1,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], h2)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate relevance back to the block input.

        Assumes attentions were not output in the forward pass, i.e.
        ``outputs == (attention_output,)`` and
        ``self_outputs == (context_layer,)``.
        """
        # cam1: relevance of the attention branch, cam2: of the residual branch.
        (cam1, cam2) = self.output.relprop(cam, **kwargs)
        cam1 = self.self.relprop(cam1, **kwargs)

        # Merge both branches back onto the tensor that was cloned in forward().
        return self.clone.relprop((cam1, cam2), **kwargs)
164
+
165
class BertSelfAttention(transformers.modeling_bert.BertSelfAttention):
    """Multi-head self-attention instrumented for relevance propagation.

    Every arithmetic step of the forward pass that must be reversible is
    routed through a relprop-aware op (``MatMul``, ``Softmax``, ``Add``,
    ``Clone``).  The attention probabilities, their gradients (captured by a
    backward hook) and their relevance are stored on the module so that
    explanation code can read them after a forward/backward/relprop cycle.

    ``query``, ``key``, ``value``, ``dropout``, ``transpose_for_scores``,
    ``attention_head_size`` and ``all_head_size`` are not defined here and
    are presumably provided by the parent class.
    """

    def __init__(self, config):
        # The parent constructor needs the config to build its projections.
        super().__init__(config)

        self.matmul1 = MatMul()
        self.matmul2 = MatMul()
        self.softmax = Softmax(dim=-1)
        self.add = Add()
        self.mul = Mul()
        # Masks seen in the last forward pass; relprop() branches on them.
        self.head_mask = None
        self.attention_mask = None
        self.clone = Clone()

        # Populated by forward() / backward hook / relprop() respectively.
        self.attn_cam = None
        self.attn = None
        self.attn_gradients = None

    def get_attn(self):
        return self.attn

    def save_attn(self, attn):
        self.attn = attn

    def save_attn_cam(self, cam):
        self.attn_cam = cam

    def get_attn_cam(self):
        return self.attn_cam

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def transpose_for_scores_relprop(self, x):
        """Swap dims 1 and 2 of a 4-D tensor and flatten the last two dims."""
        return x.permute(0, 2, 1, 3).flatten(2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Compute the attention context and record attention statistics.

        Returns:
            ``(context_layer, attention_probs)`` when ``output_attentions``
            is truthy, otherwise ``(context_layer,)``.
        """
        self.head_mask = head_mask
        self.attention_mask = attention_mask

        # One copy of the input per projection so relprop can split relevance.
        h1, h2, h3 = self.clone(hidden_states, 3)
        mixed_query_layer = self.query(h1)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        if encoder_hidden_states is not None:
            mixed_key_layer = self.key(encoder_hidden_states)
            mixed_value_layer = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            mixed_key_layer = self.key(h2)
            mixed_value_layer = self.value(h3)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = self.add([attention_scores, attention_mask])

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_scores)

        # Keep the probabilities and hook their gradient for explanation code.
        self.save_attn(attention_probs)
        attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        # NOTE(review): this multiplication bypasses self.mul, while relprop()
        # calls self.mul.relprop when a head mask is set - confirm intent.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = self.matmul2([attention_probs, value_layer])

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs

    def relprop(self, cam, **kwargs):
        """Propagate relevance from the context layer back to the input.

        Assumes ``output_attentions`` was False in the forward pass.  The
        relevance reaching the attention probabilities is stored via
        ``save_attn_cam`` on the way.
        """
        cam = self.transpose_for_scores(cam)

        # [attention_probs, value_layer]
        (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
        cam1 /= 2
        cam2 /= 2
        if self.head_mask is not None:
            # [attention_probs, head_mask]
            (cam1, _) = self.mul.relprop(cam1, **kwargs)

        self.save_attn_cam(cam1)

        cam1 = self.dropout.relprop(cam1, **kwargs)

        cam1 = self.softmax.relprop(cam1, **kwargs)

        if self.attention_mask is not None:
            # [attention_scores, attention_mask]
            (cam1, _) = self.add.relprop(cam1, **kwargs)

        # [query_layer, key_layer.transpose(-1, -2)]
        (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
        cam1_1 /= 2
        cam1_2 /= 2

        # query
        cam1_1 = self.transpose_for_scores_relprop(cam1_1)
        cam1_1 = self.query.relprop(cam1_1, **kwargs)

        # key (undo the transpose applied before matmul1 in forward())
        cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
        cam1_2 = self.key.relprop(cam1_2, **kwargs)

        # value
        cam2 = self.transpose_for_scores_relprop(cam2)
        cam2 = self.value.relprop(cam2, **kwargs)

        # Merge the three branches created by clone() in forward().
        cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)

        return cam
306
+
307
+
308
class BertSelfOutput(transformers.modeling_bert.BertSelfOutput):
    """Attention output projection with a relprop-aware residual add."""

    def __init__(self, config):
        # The parent constructor needs the config to build dense/LayerNorm/dropout.
        super().__init__(config)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection goes through Add so relprop can split relevance.
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        """Split relevance between the projected branch and the residual.

        Returns:
            ``(cam1, cam2)``: relevance for ``hidden_states`` (propagated
            back through dropout and dense) and for ``input_tensor``.
        """
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # [hidden_states, input_tensor]
        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        cam1 = self.dropout.relprop(cam1, **kwargs)
        cam1 = self.dense.relprop(cam1, **kwargs)

        return (cam1, cam2)
328
+
329
+
330
class BertIntermediate(transformers.modeling_bert.BertIntermediate):
    """Feed-forward expansion layer extended with relevance propagation."""

    def relprop(self, cam, **kwargs):
        """Propagate relevance through the activation and then the dense layer."""
        after_activation = self.intermediate_act_fn.relprop(cam, **kwargs)  # FIXME only ReLU
        return self.dense.relprop(after_activation, **kwargs)
337
+
338
+
339
class BertOutput(transformers.modeling_bert.BertOutput):
    """Feed-forward output projection with a relprop-aware residual add."""

    def __init__(self, config):
        # The parent constructor needs the config to build dense/LayerNorm/dropout.
        super().__init__(config)
        self.add = Add()

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Residual connection goes through Add so relprop can split relevance.
        add = self.add([hidden_states, input_tensor])
        hidden_states = self.LayerNorm(add)
        return hidden_states

    def relprop(self, cam, **kwargs):
        """Split relevance between the projected branch and the residual.

        Returns:
            ``(cam1, cam2)``: relevance for ``hidden_states`` (propagated
            back through dropout and dense) and for ``input_tensor``.
        """
        cam = self.LayerNorm.relprop(cam, **kwargs)
        # [hidden_states, input_tensor]
        (cam1, cam2) = self.add.relprop(cam, **kwargs)
        cam1 = self.dropout.relprop(cam1, **kwargs)
        cam1 = self.dense.relprop(cam1, **kwargs)

        return (cam1, cam2)
365
+
366
+
367
class RPBertLayer(nn.Module):
    """One transformer layer (attention + feed-forward) with relprop.

    The attention output feeds both the intermediate layer and the residual
    input of the output layer; it is duplicated through ``Clone`` so the two
    uses can receive separate relevance.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        self.clone = Clone()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Run the layer.

        Returns:
            ``(layer_output, *extra)`` where ``extra`` is everything the
            attention block returned after its first element (the attention
            weights, when they are output).
        """
        attention_result = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output, extras = attention_result[0], attention_result[1:]

        # Two copies: one for the feed-forward branch, one for the residual.
        feed_forward_in, residual = self.clone(attention_output, 2)
        layer_output = self.output(self.intermediate(feed_forward_in), residual)

        return (layer_output,) + extras

    def relprop(self, cam, **kwargs):
        """Propagate relevance from the layer output back to its input."""
        ff_relevance, residual_relevance = self.output.relprop(cam, **kwargs)
        ff_relevance = self.intermediate.relprop(ff_relevance, **kwargs)
        # Merge both branches onto the cloned attention output.
        merged = self.clone.relprop((ff_relevance, residual_relevance), **kwargs)
        return self.attention.relprop(merged, **kwargs)
408
+
409
+
410
class BertModel(BertPreTrainedModel):
    """BERT backbone assembled from the relprop-aware sub-modules.

    The forward pass follows the usual embeddings -> encoder -> pooler
    pipeline; ``relprop`` walks back through the pooler and the encoder.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Use the relprop-aware classes defined in this module (they carry the
        # ``RP`` prefix); the un-prefixed names are not defined here.
        self.embeddings = RPBertEmbeddings(config)
        self.encoder = RPBertEncoder(config)
        self.pooler = RPBertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

        Returns a tuple ``(sequence_output, pooled_output, *encoder_extras)``
        when ``return_dict`` is falsy, otherwise a
        ``BaseModelOutputWithPooling``.

        Raises ``ValueError`` when both or neither of ``input_ids`` and
        ``inputs_embeds`` are given.
        """
        # Fall back to the configuration for every flag left unspecified.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def relprop(self, cam, **kwargs):
        """Propagate relevance back through the pooler and then the encoder.

        The embeddings are not traversed here; the value returned is whatever
        the encoder's ``relprop`` produces.
        """
        cam = self.pooler.relprop(cam, **kwargs)
        cam = self.encoder.relprop(cam, **kwargs)
        return cam
529
+
530
+
531
# Point the stock transformers module at the relprop-aware replacements so
# that code resolving these names through the module picks them up.
transformers.modeling_bert.BertEmbeddings = RPBertEmbeddings
transformers.modeling_bert.BertEncoder = RPBertEncoder

if __name__ == '__main__':
    # Minimal smoke test: one forward and one relprop pass through a single
    # self-attention module with random data.
    class Config:
        """Bare-bones stand-in exposing the fields set below."""

        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob

    model = BertSelfAttention(Config(1024, 4, 0.1))
    x = torch.rand(2, 20, 1024)
    x.requires_grad_()

    model.eval()

    y = model.forward(x)

    # relprop() accepts a single positional relevance tensor; every other
    # option is passed by keyword (alpha=1 matches the explanation generator).
    relprop = model.relprop(torch.rand(2, 20, 1024), alpha=1)

    print(relprop[1][0].shape)
BERT_explainability/BertForSequenceClassification.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertPreTrainedModel
2
+ from transformers.modeling_outputs import SequenceClassifierOutput
3
+ from transformers.utils import logging
4
+ from BERT_explainability.modules.layers_ours import *
5
+ from BERT_explainability.modules.BERT.BERT import BertModel
6
+ from torch.nn import CrossEntropyLoss, MSELoss
7
+ import torch.nn as nn
8
+ from typing import List, Any
9
+ import torch
10
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
11
+
12
+
13
class BertForSequenceClassification(BertPreTrainedModel):
    """BERT with a dropout + linear classification head and relprop support."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the backbone output is the pooled representation.
        logits = self.classifier(self.dropout(backbone_outputs[1]))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Single output unit: treat the task as regression.
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                attentions=backbone_outputs.attentions,
            )

        result = (logits,) + backbone_outputs[2:]
        return result if loss is None else (loss,) + result

    def relprop(self, cam=None, **kwargs):
        """Propagate relevance through classifier, dropout and the backbone."""
        for stage in (self.classifier, self.dropout, self.bert):
            cam = stage.relprop(cam, **kwargs)
        return cam
90
+
91
+
92
+ # this is the actual classifier we will be using
93
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification.

    Builds ``[CLS] query [SEP] document`` inputs from token-id tensors, pads
    them into a batch and returns the classification output.
    """

    def __init__(self,
                 bert_dir: str,
                 pad_token_id: int,
                 cls_token_id: int,
                 sep_token_id: int,
                 num_labels: int,
                 max_length: int = 512,
                 use_half_precision=True):
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
        if use_half_precision:
            # NOTE(review): apex is imported but not referenced afterwards in
            # this block; kept because removing it changes the dependency check.
            import apex
            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(self,
                query: List[torch.Tensor],
                docids: List[Any],
                document_batch: List[torch.Tensor]):
        """Classify each (query, document) pair.

        Args:
            query: one token-id tensor per example.
            docids: accepted for interface compatibility; not used here.
            document_batch: one token-id tensor per example; truncated so
                that query + document + 2 special tokens fit ``max_length``.

        Returns:
            The single element returned by the wrapped model.
        """
        assert len(query) == len(document_batch)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
        # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
        sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[:(self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            # Positions restart at 0 for the document segment.
            position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
        bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
                                            device=target_device)
        positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
        (classes,) = self.bert(bert_input.data,
                               attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
                               position_ids=positions.data)
        assert torch.all(classes == classes)  # for nans

        # Relevance propagation is no longer triggered from inside forward():
        # callers invoke relprop() explicitly with a relevance tensor.
        return classes

    def relprop(self, cam=None, **kwargs):
        """Delegate relevance propagation to the wrapped model."""
        return self.bert.relprop(cam, **kwargs)
149
+
150
+
151
if __name__ == '__main__':
    # Manual demo: tokenize one review, load the fine-tuned classifier,
    # run a forward pass and a relevance propagation pass.
    import os
    from transformers import BertTokenizer

    class Config:
        """Plain attribute container; kept from the original demo."""

        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
                     hidden_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encode_options = dict(
        add_special_tokens=True,
        max_length=512,
        return_token_type_ids=False,
        return_attention_mask=True,
        pad_to_max_length=True,
        return_tensors='pt',
        truncation=True,
    )
    x = tokenizer.encode_plus(
        "In this movie the acting is great. The movie is perfect! [sep]",
        **encode_options,
    )

    print(x['input_ids'])

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    checkpoint_dir = './BERT_explainability/output_bert/movies/classifier/'
    model_save_file = os.path.join(checkpoint_dir, 'classifier.pt')
    model.load_state_dict(torch.load(model_save_file))

    model.eval()

    y = model(x['input_ids'], x['attention_mask'])
    print(y)

    cam, _ = model.relprop()

    # Collapse the last dimension of the relevance.
    cam = cam.sum(-1)
BERT_explainability/ExplanationGenerator.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import glob
5
+
6
+ from captum._utils.common import _get_module_from_name
7
+
8
+ # compute rollout between attention layers
9
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Roll per-layer attention matrices into one joint attention matrix.

    Each matrix gets the identity added (residual connection) and is then
    row-normalized; the normalized matrices from ``start_layer`` onwards are
    chained with batched matrix products, later layers on the left.
    Adapted from https://github.com/samiraabnar/attention_flow

    Args:
        all_layer_matrices: list of tensors shaped (batch, tokens, tokens).
        start_layer: index of the first layer included in the product.

    Returns:
        Tensor shaped like one input matrix.
    """
    reference = all_layer_matrices[0]
    batch_size, num_tokens = reference.shape[0], reference.shape[1]
    identity = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(reference.device)

    normalized = []
    for matrix in all_layer_matrices:
        with_residual = matrix + identity
        normalized.append(with_residual / with_residual.sum(dim=-1, keepdim=True))

    joint_attention = normalized[start_layer]
    for layer_idx in range(start_layer + 1, len(normalized)):
        joint_attention = normalized[layer_idx].bmm(joint_attention)
    return joint_attention
21
+
22
class Generator:
    """Produce token-relevance explanations for a relprop-capable model.

    The wrapped model is expected to expose ``relprop`` and, for every module
    found under ``key``, ``attention.self`` with ``get_attn``,
    ``get_attn_cam`` and ``get_attn_gradients`` accessors.
    """

    def __init__(self, model, key="bert.encoder.layer"):
        self.model = model
        # Dotted attribute path of the list of encoder layers inside the model.
        self.key = key
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)

    def _calculate_gradients(self, output, index, do_relprop=True):
        """Back-propagate the selected class score and optionally run relprop.

        Args:
            output: model logits; the last dimension indexes classes.
            index: class index/indices to explain; ``None`` selects the
                argmax of ``output``.
            do_relprop: when true, also run ``model.relprop`` on the one-hot
                selection and return its result.

        Returns:
            The relprop result, or ``None`` when ``do_relprop`` is false.
        """
        # Identity check: ``== None`` on an index array is elementwise and
        # cannot be used as a condition.
        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot_vector = (torch.nn.functional
            .one_hot(
                # one_hot requires ints
                torch.tensor(index, dtype=torch.int64),
                num_classes=output.size(-1)
            )
            # but requires_grad_ needs floats
            .to(torch.float)
        ).to(output.device)

        hot_output = torch.sum(one_hot_vector.clone().requires_grad_(True) * output)
        self.model.zero_grad()
        hot_output.backward(retain_graph=True)

        if do_relprop:
            return self.model.relprop(one_hot_vector, alpha=1)
        return None

    def generate_LRP(self, input_ids, attention_mask,
                     index=None, start_layer=11):
        """Rollout of gradient-weighted attention relevance over the layers."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        self._calculate_gradients(output, index)

        cams = []
        blocks = _get_module_from_name(self.model, self.key)
        for blk in blocks:
            grad = blk.attention.self.get_attn_gradients()
            cam = blk.attention.self.get_attn_cam()
            # Only the first batch element is used.
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = grad * cam
            cam = cam.clamp(min=0).mean(dim=0)
            cams.append(cam.unsqueeze(0))
        rollout = compute_rollout_attention(cams, start_layer=start_layer)
        # Suppress the first token's self-relevance.
        rollout[:, 0, 0] = rollout[:, 0].min()
        return rollout[:, 0]

    def generate_LRP_last_layer(self, input_ids, attention_mask,
                                index=None):
        """Head-averaged positive attention relevance of the last layer."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        self._calculate_gradients(output, index)

        cam = _get_module_from_name(self.model, self.key)[-1].attention.self.get_attn_cam()[0]
        cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_full_lrp(self, input_ids, attention_mask,
                          index=None):
        """Relevance returned by the model's relprop, summed over dim 2."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        cam = self._calculate_gradients(output, index)
        cam = cam.sum(dim=2)
        cam[:, 0] = 0
        return cam

    def generate_attn_last_layer(self, input_ids, attention_mask,
                                 index=None):
        """Head-averaged raw attention of the last layer (first batch element)."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        cam = _get_module_from_name(self.model, self.key)[-1].attention.self.get_attn()[0]
        cam = cam.mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None):
        """Attention rollout over head-averaged raw attention."""
        self.model.zero_grad()
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        blocks = _get_module_from_name(self.model, self.key)
        all_layer_attentions = []
        for blk in blocks:
            attn_heads = blk.attention.self.get_attn()
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            all_layer_attentions.append(avg_heads)
        rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
        rollout[:, 0, 0] = 0
        return rollout[:, 0]

    def generate_attn_gradcam(self, input_ids, attention_mask, index=None):
        """Grad-CAM style map on the last layer's attention."""
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        self._calculate_gradients(output, index)

        cam = _get_module_from_name(self.model, self.key)[-1].attention.self.get_attn()
        grad = _get_module_from_name(self.model, self.key)[-1].attention.self.get_attn_gradients()

        cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
        # One scalar weight per head.
        grad = grad.mean(dim=[1, 2], keepdim=True)
        cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
        cam = (cam - cam.min()) / (cam.max() - cam.min())
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_rollout_attn_gradcam(self, input_ids, attention_mask, index=None, start_layer=0):
        """Accumulate gradient-weighted attention across layers.

        Returns the relevance of the first token towards every token except
        the first and the last one (``R[:, 0, 1:-1]``).
        """
        # rule 5 from paper
        def avg_heads(cam, grad):
            return (grad * cam).clamp(min=0).mean(dim=-3)

        # rule 6 from paper
        def apply_self_attention_rules(R_ss, cam_ss):
            return torch.matmul(cam_ss, R_ss)

        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]

        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        # The attention gradients read below are only available after a
        # backward pass; relprop itself is not needed for this method.
        self._calculate_gradients(output, index, do_relprop=False)

        blocks = _get_module_from_name(self.model, self.key)

        num_tokens = input_ids.size(-1)
        R = torch.eye(num_tokens).expand(output.size(0), -1, -1).clone().to(output.device)

        # Iterate the layers of the wrapped model, located through self.key.
        for i, blk in enumerate(blocks):
            if i < start_layer:
                continue
            grad = blk.attention.self.get_attn_gradients().detach()
            cam = blk.attention.self.get_attn().detach()
            cam = avg_heads(cam, grad)
            joint = apply_self_attention_rules(R, cam)
            R += joint
        return R[:, 0, 1:-1]
165
+
BERT_explainability/NewExplanationGenerator.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import glob
5
+
6
+ from captum._utils.common import _get_module_from_name
7
+
8
+ # compute rollout between attention layers
9
def compute_rollout_attention(all_layer_matrices, start_layer=0):
    """Multiply per-layer attention matrices into a single rollout matrix.

    Each matrix gets the identity added to it (the residual connection)
    and is then row-normalised. The normalised matrices from
    ``start_layer`` onward are chained with batched matrix products,
    later layers multiplying from the left.

    Args:
        all_layer_matrices: sequence of tensors, one per layer, each of
            shape ``(batch, tokens, tokens)``.
        start_layer: index of the first layer included in the product.

    Returns:
        Tensor of shape ``(batch, tokens, tokens)``.
    """
    # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
    reference = all_layer_matrices[0]
    batch_size, num_tokens = reference.shape[0], reference.shape[1]
    identity = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(reference.device)

    normalized = []
    for layer_matrix in all_layer_matrices:
        with_residual = layer_matrix + identity
        normalized.append(with_residual / with_residual.sum(dim=-1, keepdim=True))

    rollout = normalized[start_layer]
    for layer_idx in range(start_layer + 1, len(normalized)):
        rollout = normalized[layer_idx].bmm(rollout)
    return rollout
21
+
22
class Generator:
    """Builds token-level explanations for a sequence classifier.

    The wrapped model is expected to expose ``relprop`` and, under the
    attribute path ``key``, a sequence of layers whose
    ``attention.self`` module offers ``get_attn``, ``get_attn_cam`` and
    ``get_attn_gradients``.
    """

    def __init__(self, model, key="bert.encoder.layer"):
        """Store the model and layer path and switch the model to eval mode."""
        self.model = model
        self.key = key
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its raw output."""
        return self.model(input_ids, attention_mask)

    def _build_one_hot(self, output, index):
        """Build the scalar to back-propagate and the matching one-hot vector.

        Args:
            output: model output whose last dimension is the class axis.
            index: class index to select; ``None`` selects the arg-max.

        Returns:
            ``(score, one_hot_vector)`` where ``score`` is the sum of
            ``output`` masked by the one-hot selection and
            ``one_hot_vector`` is the ``(1, num_classes)`` float32
            numpy array.
        """
        # Identity test: `index == None` is evaluated element-wise for
        # array indices and raises on arrays with more than one element.
        if index is None:
            index = np.argmax(output.cpu().data.numpy(), axis=-1)

        one_hot_vector = np.zeros((1, output.size()[-1]), dtype=np.float32)
        one_hot_vector[0, index] = 1
        selector = torch.from_numpy(one_hot_vector).requires_grad_(True).to(output.device)
        score = torch.sum(selector * output)

        return score, one_hot_vector

    def _propagate_relevance(self, input_ids, attention_mask, index):
        """Forward, back-propagate the selected class and run ``relprop``.

        Returns whatever ``self.model.relprop`` returns.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        score, one_hot_vector = self._build_one_hot(output, index)

        self.model.zero_grad()
        score.backward(retain_graph=True)

        return self.model.relprop(torch.tensor(one_hot_vector).to(input_ids.device), alpha=1)

    def _last_self_attention(self):
        """Return the self-attention module of the last configured layer."""
        return _get_module_from_name(self.model, self.key)[-1].attention.self

    def generate_LRP(self, input_ids, attention_mask,
                     index=None, start_layer=11):
        """Roll out gradient-weighted relevance maps across layers.

        ``start_layer`` is the first layer included in the rollout. The
        default of 11 requires at least 12 layers under ``self.key``.
        """
        self._propagate_relevance(input_ids, attention_mask, index)

        cams = []
        for blk in _get_module_from_name(self.model, self.key):
            grad = blk.attention.self.get_attn_gradients()
            cam = blk.attention.self.get_attn_cam()
            # First batch element only; leading axes are flattened.
            cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
            grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
            cam = (grad * cam).clamp(min=0).mean(dim=0)
            cams.append(cam.unsqueeze(0))
        rollout = compute_rollout_attention(cams, start_layer=start_layer)
        # Replace the first entry of row 0 with the row minimum.
        rollout[:, 0, 0] = rollout[:, 0].min()
        return rollout[:, 0]

    def generate_LRP_last_layer(self, input_ids, attention_mask,
                                index=None):
        """Return the clamped, averaged relevance map of the last layer (row 0)."""
        self._propagate_relevance(input_ids, attention_mask, index)

        cam = self._last_self_attention().get_attn_cam()[0]
        cam = cam.clamp(min=0).mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_full_lrp(self, input_ids, attention_mask,
                          index=None):
        """Return the model's ``relprop`` result summed over dim 2, first column zeroed."""
        cam = self._propagate_relevance(input_ids, attention_mask, index)
        cam = cam.sum(dim=2)
        cam[:, 0] = 0
        return cam

    def generate_attn_last_layer(self, input_ids, attention_mask,
                                 index=None):
        """Return the averaged raw attention of the last layer (row 0).

        ``index`` is accepted for interface parity and is not used.
        """
        # The forward pass is run only so the layers store their attention.
        self.model(input_ids=input_ids, attention_mask=attention_mask)
        cam = self._last_self_attention().get_attn()[0]
        cam = cam.mean(dim=0).unsqueeze(0)
        cam[:, 0, 0] = 0
        return cam[:, 0]

    def generate_rollout(self, input_ids, attention_mask, start_layer=0, index=None):
        """Return attention rollout over all layers (row 0).

        ``index`` is accepted for interface parity and is not used.
        """
        self.model.zero_grad()
        # The forward pass is run only so the layers store their attention.
        self.model(input_ids=input_ids, attention_mask=attention_mask)
        all_layer_attentions = []
        for blk in _get_module_from_name(self.model, self.key):
            attn_heads = blk.attention.self.get_attn()
            avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
            all_layer_attentions.append(avg_heads)
        rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
        rollout[:, 0, 0] = 0
        return rollout[:, 0]

    def generate_attn_gradcam(self, input_ids, attention_mask, index=None):
        """Return a Grad-CAM style map from the last attention layer (row 0)."""
        # `_build_one_hot` already falls back to the arg-max class, so no
        # separate defaulting of `index` is needed here.
        self._propagate_relevance(input_ids, attention_mask, index)

        last_attention = self._last_self_attention()
        cam = last_attention.get_attn()
        grad = last_attention.get_attn_gradients()

        cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
        grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
        # One scalar weight per head: the mean gradient over that head's map.
        grad = grad.mean(dim=[1, 2], keepdim=True)
        cam = (cam * grad).mean(0).clamp(min=0).unsqueeze(0)
        # NOTE(review): yields NaN when the map is constant (max == min).
        cam = (cam - cam.min()) / (cam.max() - cam.min())
        cam[:, 0, 0] = 0
        return cam[:, 0]
145
+
BERT_explainability/RobertaForSequenceClassification.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertPreTrainedModel
2
+ from transformers.modeling_outputs import SequenceClassifierOutput
3
+ from transformers.utils import logging
4
+ from BERT_explainability.modules.layers_ours import *
5
+ from BERT_explainability.modules.BERT.BERT import BertModel
6
+ from torch.nn import CrossEntropyLoss, MSELoss
7
+ import torch.nn as nn
8
+ from typing import List, Any
9
+ import torch
10
+ from BERT_rationale_benchmark.models.model_utils import PaddedSequence
11
+
12
+
13
class BertForSequenceClassification(BertPreTrainedModel):
    """Encoder plus dropout and a linear head, with a ``relprop`` pass through all three."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Construction order is kept: encoder, dropout, classifier.
        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the encoder output feeds the classification head.
        logits = self.classifier(self.dropout(encoder_outputs[1]))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: logits followed by any extra encoder outputs, with
        # the loss in front when labels were given.
        output = (logits,) + encoder_outputs[2:]
        return output if loss is None else ((loss,) + output)

    def relprop(self, cam=None, **kwargs):
        """Propagate ``cam`` backwards through classifier, dropout and encoder."""
        for layer in (self.classifier, self.dropout, self.bert):
            cam = layer.relprop(cam, **kwargs)
        return cam
90
+
91
+
92
+ # this is the actual classifier we will be using
93
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification"""

    def __init__(self,
                 bert_dir: str,
                 pad_token_id: int,
                 cls_token_id: int,
                 sep_token_id: int,
                 num_labels: int,
                 max_length: int = 512,
                 use_half_precision=True):
        """Load the classifier from ``bert_dir`` and remember the special token ids.

        Args:
            bert_dir: location passed to ``from_pretrained``.
            pad_token_id: id used to pad batched inputs.
            cls_token_id: id placed at the start of every input.
            sep_token_id: id placed between query and document.
            num_labels: number of output classes.
            max_length: maximum length of an assembled input sequence.
            use_half_precision: convert the model with ``.half()``.
        """
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
        if use_half_precision:
            # `.half()` is native to torch; the previous unused
            # `import apex` only made this path fail when apex was missing.
            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(self,
                query: List[torch.Tensor],
                docids: List[Any],
                document_batch: List[torch.Tensor]):
        """Classify each (query, document) pair.

        Every input is assembled as ``[CLS] query [SEP] document``. The
        document is truncated so the assembled sequence does not exceed
        ``max_length``. ``docids`` is not used.

        Returns:
            The logits returned by the wrapped model.

        Raises:
            AssertionError: if the two input lists differ in length or
                the logits contain NaN.
        """
        assert len(query) == len(document_batch)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
        # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
        sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[:(self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            # Positions restart at 0 for the document segment.
            position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
        bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
                                            device=target_device)
        positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
        # Force the tuple return form. Unpacking a ModelOutput iterates
        # its keys, so `(classes,)` would otherwise bind a string.
        (classes,) = self.bert(bert_input.data,
                               attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
                               position_ids=positions.data,
                               return_dict=False)
        assert torch.all(classes == classes)  # for nans

        # The debug prints and the `self.relprop()` call that used to sit
        # here ran relevance propagation without a starting relevance on
        # every forward pass; they have been removed.
        return classes

    def relprop(self, cam=None, **kwargs):
        """Delegate relevance propagation to the wrapped model."""
        return self.bert.relprop(cam, **kwargs)
149
+
150
+
151
# Manual smoke test: tokenises one sentence, loads a fine-tuned
# classifier from disk, runs it and calls relprop on the result.
if __name__ == '__main__':
    from transformers import BertTokenizer
    import os

    # Plain attribute holder; it is not referenced anywhere else in this block.
    class Config:
        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
                     hidden_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob


    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Encode one sentence, padded to 512 tokens, as PyTorch tensors.
    # NOTE(review): `pad_to_max_length` is deprecated in recent
    # transformers releases in favour of `padding` -- confirm the
    # installed version.
    x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]",
                              add_special_tokens=True,
                              max_length=512,
                              return_token_type_ids=False,
                              return_attention_mask=True,
                              pad_to_max_length=True,
                              return_tensors='pt',
                              truncation=True)

    print(x['input_ids'])

    # Start from the pretrained weights, then overwrite them with the
    # state dict stored at the path below.
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt')
    model.load_state_dict(torch.load(model_save_file))

    # Alternative hand-written inputs, kept for reference:
    # x = torch.randint(100, (2, 20))
    # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
    #                    101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
    #                    2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
    #                    102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
    #                    101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
    #                    2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
    #                    102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
    #                    101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
    #                    1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
    #                    102, 101, 1012, 102]])
    # x.requires_grad_()

    model.eval()

    y = model(x['input_ids'], x['attention_mask'])
    print(y)

    # NOTE(review): relprop is called without a starting relevance
    # (cam=None) and its result is unpacked into two values -- confirm
    # that the model's relprop supports both.
    cam, _ = model.relprop()

    #print(cam.shape)

    # Collapse the last axis of the relevance tensor.
    cam = cam.sum(-1)
    #print(cam)
BERT_explainability/roberta2.py ADDED
@@ -0,0 +1,1596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RoBERTa model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from transformers.activations import ACT2FN, gelu
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.roberta.configuration_roberta import RobertaConfig
48
+
49
+
50
# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)

# Names used when generating documentation.
# NOTE(review): presumably consumed by the docstring decorators imported
# at the top of the file -- confirm against their call sites.
_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"

# Identifiers of published RoBERTa checkpoints.
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    "roberta-base-openai-detector",
    "roberta-large-openai-detector",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]
65
+
66
+
67
class RobertaEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        max_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            all_zero_types = torch.zeros(self.position_ids.size(), dtype=torch.long)
            self.register_buffer("token_type_ids", all_zero_types, persistent=False)

        # End copy
        # The position table is rebuilt with a padding index; the first
        # construction above is kept so parameter initialisation happens
        # in the same order as before.
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(max_positions, hidden_size, padding_idx=self.padding_idx)

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Sum word, token-type and (absolute) position embeddings, then normalise and apply dropout."""
        if position_ids is None:
            if input_ids is None:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
            else:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)

        input_shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = input_shape[1]

        # Fall back to the all-zero buffer registered in the constructor
        # when the caller gives no token types (helps tracing, issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        first_position = self.padding_idx + 1

        position_ids = torch.arange(
            first_position, input_shape[1] + first_position, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
155
+
156
+
157
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
158
class RobertaSelfAttention(nn.Module):
    """Multi-head self-attention that records its attention for explanations.

    Every forward pass stores the attention probabilities on the module
    (``get_attn``). When they take part in autograd, their gradient is
    captured during the backward pass (``get_attn_gradients``).
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def get_attn(self):
        """Return the attention probabilities stored by the last forward pass."""
        return self.attn

    def save_attn(self, attn):
        """Store the attention probabilities."""
        self.attn = attn

    def save_attn_cam(self, cam):
        """Store a relevance map for this layer's attention."""
        self.attn_cam = cam

    def get_attn_cam(self):
        """Return the stored relevance map."""
        return self.attn_cam

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient of the attention probabilities (used as a tensor hook)."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradient captured during the last backward pass."""
        return self.attn_gradients

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into heads and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute attention over ``hidden_states`` (or over encoder states for cross-attention).

        Returns:
            A tuple starting with the context layer, followed by the
            attention probabilities when ``output_attentions`` is set
            and by the key/value pair when the module is a decoder.
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        self.save_attn(attention_probs)
        # A hook can only be attached to a tensor that requires grad.
        # Without this guard the layer raised RuntimeError under
        # `torch.no_grad()` or with fully frozen inputs and weights.
        # Previously stored gradients, if any, are left untouched.
        if attention_probs.requires_grad:
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the head axis back into the hidden axis.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
304
+
305
+
306
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
307
class RobertaSelfOutput(nn.Module):
    """Projects the attention output, applies dropout and normalises the residual sum."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
319
+
320
+
321
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
322
class RobertaAttention(nn.Module):
    """Attention sub-layer: `RobertaSelfAttention` followed by `RobertaSelfOutput`, with head pruning."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = RobertaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the attention heads listed in `heads` from the projections of this sub-layer."""
        if len(heads) == 0:
            return

        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune the three input projections, then the output projection along its input dimension.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in sync with the shrunken projections.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention and its output projection.

        Returns a tuple whose first element is the projected attention output;
        whatever else the inner attention module returned is appended unchanged.
        """
        attn_result = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # The residual connection uses the layer input, not the attention context.
        projected = self.output(attn_result[0], hidden_states)
        return (projected,) + attn_result[1:]
369
+
370
+
371
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
372
class RobertaIntermediate(nn.Module):
    """First half of the feed-forward sub-layer: dense expansion followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN; anything else is used as the activation directly.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense layer and then the configured activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
385
+
386
+
387
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
388
class RobertaOutput(nn.Module):
    """Second half of the feed-forward sub-layer.

    Projects from `intermediate_size` back to `hidden_size`, applies dropout,
    adds the residual input and layer-normalises the sum.
    """

    def __init__(self, config):
        super().__init__()
        # Creation order is kept so state-dict key order is unchanged.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, apply dropout, add `input_tensor` and layer-normalise."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
400
+
401
+
402
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
403
class RobertaLayer(nn.Module):
    """One transformer block: self-attention, optional cross-attention, then the feed-forward sub-layer."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RobertaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        # Cross-attention is only valid for decoders; reject the combination before building it.
        if self.add_cross_attention and not self.is_decoder:
            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
        if self.add_cross_attention:
            self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run the block and return `(layer_output, *attention extras[, present_key_value])`.

        The trailing `present_key_value` element is only appended when the layer
        is configured as a decoder.
        """
        # The first two entries of the cached tuple belong to self-attention.
        cached_self_kv = past_key_value[:2] if past_key_value is not None else None
        self_attn_out = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attn_out[0]

        if self.is_decoder:
            # For a decoder the last element is the self-attention cache, not an attention map.
            extras = self_attn_out[1:-1]
            present_key_value = self_attn_out[-1]
        else:
            extras = self_attn_out[1:]

        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # The last two entries of the cached tuple belong to cross-attention.
            cached_cross_kv = past_key_value[-2:] if past_key_value is not None else None
            cross_out = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cached_cross_kv,
                output_attentions,
            )
            attention_output = cross_out[0]
            extras = extras + cross_out[1:-1]
            # Append the cross-attention cache after the self-attention cache.
            present_key_value = present_key_value + cross_out[-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        result = (layer_output,) + extras
        if self.is_decoder:
            result = result + (present_key_value,)
        return result

    def feed_forward_chunk(self, attention_output):
        """Feed-forward sub-layer applied to (a chunk of) the attention output."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)
487
+
488
+
489
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
490
class RobertaEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `RobertaLayer` modules applied one after another."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run every layer in order and collect the requested side outputs.

        Args:
            hidden_states: input to the first layer.
            attention_mask: passed unchanged to every layer.
            head_mask: indexed per layer (`head_mask[i]`) when not None.
            encoder_hidden_states / encoder_attention_mask: passed unchanged to every layer.
            past_key_values: indexed per layer (`past_key_values[i]`) when not None.
            use_cache: when truthy, the last element of each layer's output is collected as the cache.
            output_attentions: collect element 1 (and element 2 when `config.add_cross_attention`)
                of each layer's output.
            output_hidden_states: collect the input of every layer plus the final output.
            return_dict: return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.

        Returns:
            With `return_dict` falsy, a tuple of the non-None values among
            `(hidden_states, cache, all_hidden_states, self_attentions, cross_attentions)`;
            otherwise the same values wrapped in the output dataclass.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Resolve the checkpointing/cache conflict BEFORE the cache tuple is created. Doing it
        # inside the loop left `next_decoder_cache` as an empty tuple, which leaked into the
        # outputs as `()` instead of `None` even though caching had been switched off.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    # checkpoint() only forwards positional tensors, so the remaining
                    # arguments are bound through this closure.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
585
+
586
+
587
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
588
class RobertaPooler(nn.Module):
    """Pools a sequence by passing the hidden state at position 0 through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the pooled representation built from index 0 along dimension 1."""
        # "Pooling" here is simply selecting the first position of every sequence.
        leading_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_state))
601
+
602
+
603
class RobertaPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that handles weight initialization and provides the simple interface for downloading and
    loading pretrained models.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"
    supports_gradient_checkpointing = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # The padding row, when there is one, is reset to zeros after the random init.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `module` when it is a `RobertaEncoder`."""
        if isinstance(module, RobertaEncoder):
            module.gradient_checkpointing = value

    def update_keys_to_ignore(self, config, del_keys_to_ignore):
        """Remove some keys from ignore list"""
        if config.tie_word_embeddings:
            return

        # New lists are bound on the instance so the class-level lists stay untouched.
        self._keys_to_ignore_on_save = [
            key for key in self._keys_to_ignore_on_save if key not in del_keys_to_ignore
        ]
        self._keys_to_ignore_on_load_missing = [
            key for key in self._keys_to_ignore_on_load_missing if key not in del_keys_to_ignore
        ]
642
+
643
+
644
# Shared introduction passed to the `add_start_docstrings` decorators on the model classes below.
# NOTE(review): indentation inside the literal was restored from a whitespace-stripped copy;
# confirm against the upstream file if exact doc rendering matters.
ROBERTA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`RobertaConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation shared by the `forward` methods. `{0}` is a `str.format` placeholder that
# the decorators in this file fill with the input shape text ("batch_size, sequence_length").
# NOTE(review): indentation inside the literal was restored from a whitespace-stripped copy.
ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`RobertaTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
708
+
709
+
710
@add_start_docstrings(
    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    ROBERTA_START_DOCSTRING,
)
class RobertaModel(RobertaPreTrainedModel):
    """

    The model can act as an encoder (self-attention only) or as a decoder, in which case a cross-attention layer is
    inserted between the self-attention layers, following the architecture described in *Attention is all you need*
    (https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` and
    `add_cross_attention` set to `True`; `encoder_hidden_states` is then expected as an input to the forward pass.

    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

        # The pooler is optional; heads that do not need it pass add_pooling_layer=False.
        self.pooler = RobertaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_index, layer_heads in heads_to_prune.items():
            self.encoder.layer[layer_index].attention.prune_heads(layer_heads)

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        cfg = self.config

        # Unset flags fall back to the configuration defaults.
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Caching is only ever enabled when the model is configured as a decoder.
        if not cfg.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = cfg.use_cache

        # Exactly one of input_ids / inputs_embeds must be given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = inputs_embeds.device if input_ids is None else input_ids.device

        # Length of the cached prefix, read from the first cached tensor.
        past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the buffer registered on the embeddings, cut and expanded to the input shape.
                buffered_ids = self.embeddings.token_type_ids[:, :seq_length]
                token_type_ids = buffered_ids.expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        encoder_extended_attention_mask = None
        if cfg.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, cfg.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
895
+
896
+
897
+ @add_start_docstrings(
898
+ """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING
899
+ )
900
+ class RobertaForCausalLM(RobertaPreTrainedModel):
901
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
902
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
903
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
904
+
905
+ def __init__(self, config):
906
+ super().__init__(config)
907
+
908
+ if not config.is_decoder:
909
+ logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
910
+
911
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
912
+ self.lm_head = RobertaLMHead(config)
913
+
914
+ # The LM head weights require special treatment only when they are tied with the word embeddings
915
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
916
+
917
+ # Initialize weights and apply final processing
918
+ self.post_init()
919
+
920
+ def get_output_embeddings(self):
921
+ return self.lm_head.decoder
922
+
923
+ def set_output_embeddings(self, new_embeddings):
924
+ self.lm_head.decoder = new_embeddings
925
+
926
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
927
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
928
+ def forward(
929
+ self,
930
+ input_ids: Optional[torch.LongTensor] = None,
931
+ attention_mask: Optional[torch.FloatTensor] = None,
932
+ token_type_ids: Optional[torch.LongTensor] = None,
933
+ position_ids: Optional[torch.LongTensor] = None,
934
+ head_mask: Optional[torch.FloatTensor] = None,
935
+ inputs_embeds: Optional[torch.FloatTensor] = None,
936
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
937
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
938
+ labels: Optional[torch.LongTensor] = None,
939
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
940
+ use_cache: Optional[bool] = None,
941
+ output_attentions: Optional[bool] = None,
942
+ output_hidden_states: Optional[bool] = None,
943
+ return_dict: Optional[bool] = None,
944
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
945
+ r"""
946
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
947
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
948
+ the model is configured as a decoder.
949
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
950
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
951
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
952
+
953
+ - 1 for tokens that are **not masked**,
954
+ - 0 for tokens that are **masked**.
955
+
956
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
957
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
958
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
959
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
960
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
961
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
962
+
963
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
964
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
965
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
966
+ use_cache (`bool`, *optional*):
967
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
968
+ `past_key_values`).
969
+
970
+ Returns:
971
+
972
+ Example:
973
+
974
+ ```python
975
+ >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig
976
+ >>> import torch
977
+
978
+ >>> tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
979
+ >>> config = RobertaConfig.from_pretrained("roberta-base")
980
+ >>> config.is_decoder = True
981
+ >>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
982
+
983
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
984
+ >>> outputs = model(**inputs)
985
+
986
+ >>> prediction_logits = outputs.logits
987
+ ```"""
988
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
989
+ if labels is not None:
990
+ use_cache = False
991
+
992
+ outputs = self.roberta(
993
+ input_ids,
994
+ attention_mask=attention_mask,
995
+ token_type_ids=token_type_ids,
996
+ position_ids=position_ids,
997
+ head_mask=head_mask,
998
+ inputs_embeds=inputs_embeds,
999
+ encoder_hidden_states=encoder_hidden_states,
1000
+ encoder_attention_mask=encoder_attention_mask,
1001
+ past_key_values=past_key_values,
1002
+ use_cache=use_cache,
1003
+ output_attentions=output_attentions,
1004
+ output_hidden_states=output_hidden_states,
1005
+ return_dict=return_dict,
1006
+ )
1007
+
1008
+ sequence_output = outputs[0]
1009
+ prediction_scores = self.lm_head(sequence_output)
1010
+
1011
+ lm_loss = None
1012
+ if labels is not None:
1013
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1014
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1015
+ labels = labels[:, 1:].contiguous()
1016
+ loss_fct = CrossEntropyLoss()
1017
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1018
+
1019
+ if not return_dict:
1020
+ output = (prediction_scores,) + outputs[2:]
1021
+ return ((lm_loss,) + output) if lm_loss is not None else output
1022
+
1023
+ return CausalLMOutputWithCrossAttentions(
1024
+ loss=lm_loss,
1025
+ logits=prediction_scores,
1026
+ past_key_values=outputs.past_key_values,
1027
+ hidden_states=outputs.hidden_states,
1028
+ attentions=outputs.attentions,
1029
+ cross_attentions=outputs.cross_attentions,
1030
+ )
1031
+
1032
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
    """
    Assemble the keyword arguments for one generation step.

    Args:
        input_ids: token ids generated so far.
        past: cached key/value states from earlier steps, or `None`.
        attention_mask: optional mask; when omitted an all-ones mask shaped like `input_ids` is used.
        **model_kwargs: accepted and ignored.

    Returns:
        dict with `input_ids`, `attention_mask` and `past_key_values`.
    """
    # The default mask is built from the full sequence, i.e. before any trimming below.
    mask = input_ids.new_ones(input_ids.shape) if attention_mask is None else attention_mask

    # With a cache present only the most recent token has to be fed to the model.
    step_ids = input_ids if past is None else input_ids[:, -1:]

    return {"input_ids": step_ids, "attention_mask": mask, "past_key_values": past}
1043
+
1044
+ def _reorder_cache(self, past, beam_idx):
1045
+ reordered_past = ()
1046
+ for layer_past in past:
1047
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1048
+ return reordered_past
1049
+
1050
+
1051
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
class RobertaForMaskedLM(RobertaPreTrainedModel):
    # Key patterns handed to the base class for checkpoint save/load filtering.
    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        """Build the encoder (without pooling layer) and the LM head from `config`."""
        super().__init__(config)

        # Only a warning: construction continues even when the config is a decoder config.
        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        # The LM head weights require special treatment only when they are tied with the word embeddings
        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head's decoder layer."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's decoder layer with `new_embeddings`."""
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
        expected_output="' Paris'",
        expected_loss=0.1,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.
        """
        # An explicit argument wins; otherwise fall back to the config default.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The first encoder output is what the LM head consumes.
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        # The loss is only computed when labels are supplied.
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten scores to (-1, vocab_size) and labels to 1-D for the loss.
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Tuple form: scores followed by everything in `outputs` after its first two entries,
            # with the loss prepended when it was computed.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1147
+
1148
+
1149
class RobertaLMHead(nn.Module):
    """Masked-LM head: dense layer, GELU, layer norm, then a projection to vocabulary size."""

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)

        self.decoder = nn.Linear(hidden, config.vocab_size)
        # The head owns a zero-initialised bias and makes the decoder use that same parameter.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        transformed = self.layer_norm(gelu(self.dense(features)))

        # project back to size of vocabulary with bias
        return self.decoder(transformed)

    def _tie_weights(self):
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        self.bias = self.decoder.bias
1174
+
1175
+
1176
@add_start_docstrings(
    """
    RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class RobertaForSequenceClassification(RobertaPreTrainedModel):
    # Key pattern handed to the base class for checkpoint load filtering.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder (without pooling layer) and the classification head from `config`."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.classifier = RobertaClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint="cardiffnlp/twitter-roberta-base-emotion",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'optimism'",
        expected_loss=0.08,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # An explicit argument wins; otherwise fall back to the config default.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The classification head receives the full first encoder output.
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # When the config does not name a problem type, infer it once from the label count
            # and label dtype. The result is written back to the config, so later calls reuse it.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # Pick the loss matching the problem type; an unrecognised type leaves `loss` as None.
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Tuple form: logits followed by everything in `outputs` after its first two entries,
            # with the loss prepended when it was computed.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1274
+
1275
+
1276
@add_start_docstrings(
    """
    Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class RobertaForMultipleChoice(RobertaPreTrainedModel):
    # Key pattern handed to the base class for checkpoint load filtering.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder, a dropout layer and a single-output linear scorer from `config`."""
        super().__init__(config)

        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One score per flattened row; rows are regrouped per example in `forward`.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        # An explicit argument wins; otherwise fall back to the config default.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The choice count is dimension 1 of whichever input was supplied.
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Collapse all leading dimensions into one so the encoder sees one row per choice;
        # inputs that were not supplied stay None.
        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.roberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # This head reads the second encoder output (named pooled output here), not the first.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Regroup the per-row scores into rows of `num_choices` columns.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            # Tuple form: logits followed by everything in `outputs` after its first two entries,
            # with the loss prepended when it was computed.
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1367
+
1368
+
1369
@add_start_docstrings(
    """
    Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ROBERTA_START_DOCSTRING,
)
class RobertaForTokenClassification(RobertaPreTrainedModel):
    # Key patterns handed to the base class for checkpoint load filtering.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder (without pooling layer), dropout and a per-token linear classifier."""
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # `classifier_dropout` takes precedence; `hidden_dropout_prob` is used when it is None.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint="Jean-Baptiste/roberta-large-ner-english",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']",
        expected_loss=0.01,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # An explicit argument wins; otherwise fall back to the config default.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Dropout, then the linear layer applied to the first encoder output.
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten logits to (-1, num_labels) and labels to 1-D for the loss.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple form: logits followed by everything in `outputs` after its first two entries,
            # with the loss prepended when it was computed.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1454
+
1455
+
1456
class RobertaClassificationHead(nn.Module):
    """Classification head applied to the first token's hidden state of each sequence."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `classifier_dropout` takes precedence; `hidden_dropout_prob` is used when it is None.
        if config.classifier_dropout is None:
            dropout_prob = config.hidden_dropout_prob
        else:
            dropout_prob = config.classifier_dropout
        self.dropout = nn.Dropout(dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # take <s> token (equiv. to [CLS])
        first_token = features[:, 0, :]
        projected = self.dense(self.dropout(first_token))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)
1476
+
1477
+
1478
@add_start_docstrings(
    """
    Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ROBERTA_START_DOCSTRING,
)
class RobertaForQuestionAnswering(RobertaPreTrainedModel):
    # Key patterns handed to the base class for checkpoint load filtering.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the encoder (without pooling layer) and the linear span-logit layer from `config`."""
        super().__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        # NOTE: `forward` unpacks this layer's output into exactly two tensors (start, end),
        # which only works when `config.num_labels` is 2.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint="deepset/roberta-base-squad2",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="' puppet'",
        expected_loss=0.86,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        # An explicit argument wins; otherwise fall back to the config default.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the last dimension into size-1 chunks, then drop that dimension from each chunk.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # The loss needs both position tensors; with only one supplied it stays None.
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # Clamping maps every position past the end onto `ignored_index`, which the loss skips.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # Average of the start and end losses.
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Tuple form: both logit tensors followed by everything in `outputs` after its first
            # two entries, with the loss prepended when it was computed.
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1581
+
1582
+
1583
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Build position ids in which padding tokens keep `padding_idx` and every other token gets
    `padding_idx + past_key_values_length + k`, where k is its 1-based rank among the non-padding
    tokens of its row. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: token id tensor; positions are counted along dimension 1.
        padding_idx: id that marks padding in `input_ids`.
        past_key_values_length: offset added to every non-padding position.

    Returns: torch.Tensor of integer (`long`) position ids.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    # Multiplying by the mask zeroes the padding slots before the final shift.
    offset_positions = (running_count + past_key_values_length) * not_padding
    return offset_positions.long() + padding_idx
BERT_explainability/roberta2.py.rej ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --- modeling_roberta.py 2022-06-28 11:59:19.974278244 +0200
2
+ +++ roberta2.py 2022-06-28 14:13:05.765050058 +0200
3
+ @@ -23,14 +23,14 @@
4
+ from torch import nn
5
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
+
7
+ -from ...activations import ACT2FN, gelu
8
+ -from ...file_utils import (
9
+ +from transformers.activations import ACT2FN, gelu
10
+ +from transformers.file_utils import (
11
+ add_code_sample_docstrings,
12
+ add_start_docstrings,
13
+ add_start_docstrings_to_model_forward,
14
+ replace_return_docstrings,
15
+ )
16
+ -from ...modeling_outputs import (
17
+ +from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPastAndCrossAttentions,
19
+ BaseModelOutputWithPoolingAndCrossAttentions,
20
+ CausalLMOutputWithCrossAttentions,
21
+ @@ -40,14 +40,14 @@
22
+ SequenceClassifierOutput,
23
+ TokenClassifierOutput,
24
+ )
25
+ -from ...modeling_utils import (
26
+ +from transformers.modeling_utils import (
27
+ PreTrainedModel,
28
+ apply_chunking_to_forward,
29
+ find_pruneable_heads_and_indices,
30
+ prune_linear_layer,
31
+ )
32
+ -from ...utils import logging
33
+ -from .configuration_roberta import RobertaConfig
34
+ +from transformers.utils import logging
35
+ +from transformers.models.roberta.configuration_roberta import RobertaConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+ @@ -183,6 +183,24 @@
40
+
41
+ self.is_decoder = config.is_decoder
42
+
43
+ + def get_attn(self):
44
+ + return self.attn
45
+ +
46
+ + def save_attn(self, attn):
47
+ + self.attn = attn
48
+ +
49
+ + def save_attn_cam(self, cam):
50
+ + self.attn_cam = cam
51
+ +
52
+ + def get_attn_cam(self):
53
+ + return self.attn_cam
54
+ +
55
+ + def save_attn_gradients(self, attn_gradients):
56
+ + self.attn_gradients = attn_gradients
57
+ +
58
+ + def get_attn_gradients(self):
59
+ + return self.attn_gradients
60
+ +
61
+ def transpose_for_scores(self, x):
62
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
63
+ x = x.view(*new_x_shape)
app.py CHANGED
@@ -1,7 +1,194 @@
 
1
  import gradio
2
 
3
- def greet(name):
4
- return f"Hello {name}. Check back soon for real content"
5
 
6
- iface = gradio.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  iface.launch()
 
1
+ import sys
2
  import gradio
3
 
4
+ sys.path.append("BERT_explainability")
 
5
 
6
+ import torch
7
+
8
+ from BERT_explainability.ExplanationGenerator import Generator
9
+ from BERT_explainability.roberta2 import RobertaForSequenceClassification
10
+ from transformers import AutoTokenizer
11
+
12
+ from captum.attr import (
13
+ visualization
14
+ )
15
+ import torch
16
+
17
+ # from https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
18
class PyTMinMaxScalerVectorized(object):
    """
    Rescale a tensor, in place, to the range [0, 1] along one dimension.

    Each slice along `dimension` is mapped with `(x - min) / (max - min)`.
    A slice whose values are all equal (including a slice of length one) is
    mapped to zeros instead of dividing by zero.
    """

    def __init__(self, dimension=-1):
        # Dimension along which minimum and maximum are taken.
        self.d = dimension

    def __call__(self, tensor):
        """Scale `tensor` in place and return that same tensor object."""
        d = self.d
        # Read both extremes before mutating anything so the formula is explicit.
        lowest = tensor.min(dim=d, keepdim=True)[0]
        span = tensor.max(dim=d, keepdim=True)[0] - lowest
        # A zero span would produce inf/NaN; dividing by one leaves the zeros from the subtraction.
        span = torch.where(span == 0, torch.ones_like(span), span)
        tensor.sub_(lowest).div_(span)
        return tensor
29
+
30
+
31
# Run on the GPU when torch reports one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the classifier checkpoint, move it to the chosen device and switch it to eval mode.
model = RobertaForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2")
model = model.to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
# initialize the explanations generator
explanations = Generator(model, "roberta")

# Display names, looked up by predicted class index.
classifications = ["NEGATIVE", "POSITIVE"]
43
+
44
+ # rule 5 from paper
45
def avg_heads(cam, grad):
    """
    Weight `cam` element-wise by `grad`, zero out negative products, and average over dimension -3.

    NOTE(review): dimension -3 is presumably the attention-head axis of a
    (batch, heads, tokens, tokens) map — confirm against the attention module.
    """
    weighted = grad * cam
    # set negative values to 0, then average
    non_negative = weighted.clamp(min=0)
    return non_negative.mean(dim=-3)
54
+
55
+ # rule 6 from paper
56
def apply_self_attention_rules(R_ss, cam_ss):
    """Return the matrix product `cam_ss @ R_ss` (in that order), the increment added to the relevance."""
    return cam_ss @ R_ss
59
+
60
def generate_relevance(model, input_ids, attention_mask, index=None, start_layer=0):
    """
    Compute per-token relevance for a classification from gradient-weighted attention maps.

    Args:
        model: classifier whose call result has the class scores as element 0 and which exposes
            `model.roberta.encoder.layer`; each layer's `attention.self` must provide
            `get_attn()` and `get_attn_gradients()`.
        input_ids: passed to `model` unchanged.
        attention_mask: passed to `model` unchanged.
        index: class index (or one index per example) to explain. `None` explains each
            example's highest-scoring class.
        start_layer: encoder layers at a position below this value are skipped.

    Returns:
        `(output, relevance)`: `output` is the class-score tensor; `relevance` is row 0 of the
        accumulated relevance matrix with its first and last columns dropped.
    """
    output = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    if index is None:
        # by default explain the class with the highest score
        index = output.argmax(dim=-1)

    # create a one-hot vector selecting class we want explanations for, on the same device as
    # the scores so this function does not depend on a module-level device
    target = torch.as_tensor(index, dtype=torch.int64, device=output.device)
    one_hot = torch.nn.functional.one_hot(target, num_classes=output.size(-1)).to(torch.float)
    selected_score = torch.sum(one_hot * output)
    model.zero_grad()
    # create the gradients for the class we're interested in
    selected_score.backward(retain_graph=True)

    # NOTE(review): assumes the forward and backward passes above populated each layer's saved
    # attention map and gradients — confirm against the patched attention module.
    layers = model.roberta.encoder.layer
    num_tokens = layers[0].attention.self.get_attn().shape[-1]
    # One identity matrix per example; cloned because it is updated in place below.
    R = torch.eye(num_tokens, device=output.device).expand(output.size(0), -1, -1).clone()

    for i, blk in enumerate(layers):
        if i < start_layer:
            continue
        grad = blk.attention.self.get_attn_gradients()
        cam = blk.attention.self.get_attn()
        cam = avg_heads(cam, grad)
        R += apply_self_attention_rules(R, cam)
    return output, R[:, 0, 1:-1]
92
+
93
def visualize_text(datarecords, legend=True):
    """
    Render explanation records as an HTML string containing one table row per record.

    Args:
        datarecords: iterable of records exposing `true_class`, `pred_class`, `pred_prob`,
            `attr_class`, `attr_score`, `raw_input_ids` and `word_attributions`.
        legend: when true, a Negative/Neutral/Positive colour legend is included.

    Returns:
        The assembled HTML as a single string.
    """
    # The formatting helpers live in captum's `visualization` module, which is the only name
    # this file imports from it; they must be reached through that module.
    # NOTE(review): `_get_color` is a private captum helper — confirm it exists in the pinned version.
    dom = ['<table style="width: 100%">']
    rows = [
        "<tr><th>True Label</th>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th>"
    ]
    for datarecord in datarecords:
        rows.append(
            "".join(
                [
                    "<tr>",
                    visualization.format_classname(datarecord.true_class),
                    visualization.format_classname(
                        "{0} ({1:.2f})".format(
                            datarecord.pred_class, datarecord.pred_prob
                        )
                    ),
                    visualization.format_classname(datarecord.attr_class),
                    visualization.format_classname("{0:.2f}".format(datarecord.attr_score)),
                    visualization.format_word_importances(
                        datarecord.raw_input_ids, datarecord.word_attributions
                    ),
                    # Close the row; an opening tag here would start a second, empty row.
                    "</tr>",
                ]
            )
        )

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; \
            padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; \
                border: 1px solid; background-color: \
                {value}"></span> {label} '.format(
                    value=visualization._get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    html = "".join(dom)

    return html
145
+
146
def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0):
    """
    Generate relevance scores for a batch and return them rendered as HTML.

    Args:
        model, input_ids, attention_mask, index, start_layer: forwarded to `generate_relevance`.

    Returns:
        The HTML string produced by `visualize_text` for one record per batch row.
    """
    # generate an explanation for the input
    logits, relevance = generate_relevance(
        model, input_ids, attention_mask, index=index, start_layer=start_layer
    )
    print(logits.shape, relevance.shape)

    # normalize scores (the scaler works in place and returns the same tensor)
    normalizer = PyTMinMaxScalerVectorized()
    norm = normalizer(relevance)

    # get the model classification
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    vis_data_records = []
    batch_size = input_ids.size(0)
    for record in range(batch_size):
        classification = probabilities[record].argmax(dim=-1).item()
        nrm = norm[record]

        # if the classification is negative, higher explanation scores are more negative
        # flip for visualization
        if classifications[classification] == "NEGATIVE":
            nrm *= (-1)

        # Drop the first token, the last non-padding token and all padding from the token list.
        padding_count = (attention_mask[record] == 0).sum().item()
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())
        tokens = all_tokens[1:0 - (padding_count + 1)]
        print([(token, nrm[i].item()) for i, token in enumerate(tokens)])

        # Arguments are positional; see captum's VisualizationDataRecord for their order.
        data_record = visualization.VisualizationDataRecord(
            nrm,
            probabilities[record][classification],
            classification,
            classification,
            index,
            1,
            tokens,
            1)
        vis_data_records.append(data_record)
    return visualize_text(vis_data_records)
180
+
181
def run(input_text):
    """
    Gradio handler: tokenize one input string and return its explanation as HTML.

    Args:
        input_text: the text to classify and explain.

    Returns:
        The HTML string produced by `show_explanation`.
    """
    # The text comes straight from the web UI, so ask the tokenizer to truncate over-long input.
    # NOTE(review): truncation uses the tokenizer's configured maximum length — confirm that this
    # checkpoint's tokenizer sets one.
    encoding = tokenizer([input_text], return_tensors='pt', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    return show_explanation(model, input_ids, attention_mask)
192
+
193
# The handler must be `run`; the placeholder `greet` function is no longer defined in this file.
iface = gradio.Interface(
    fn=run,
    inputs="text",
    outputs="html",
    examples=[
        ["This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"],
        ["I really didn't like this movie. Some of the actors were good, but overall the movie was boring"],
    ],
)
194
  iface.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers==4.21.2
3
+ captum
4
+