davda54 commited on
Commit
1e8485b
·
1 Parent(s): 450b6bd

Optimized wrapper with correct API

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +64 -36
modeling_norbert.py CHANGED
@@ -1,12 +1,9 @@
1
- from __future__ import absolute_import, division, print_function, unicode_literals
2
-
3
  import math
4
  from typing import List, Optional, Tuple, Union
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
- from torch import _softmax_backward_data as _softmax_backward_data
10
  from torch.utils import checkpoint
11
 
12
  from configuration_norbert import NorbertConfig
@@ -20,6 +17,7 @@ from transformers.modeling_outputs import (
20
  TokenClassifierOutput,
21
  BaseModelOutput
22
  )
 
23
 
24
 
25
  class Encoder(nn.Module):
@@ -130,8 +128,8 @@ class MaskedSoftmax(torch.autograd.Function):
130
  @staticmethod
131
  def backward(self, grad_output):
132
  output, = self.saved_tensors
133
- inputGrad = _softmax_backward_data(grad_output, output, self.dim, output.dtype)
134
- return inputGrad, None, None
135
 
136
 
137
  class Attention(nn.Module):
@@ -188,31 +186,36 @@ class Attention(nn.Module):
188
  if self.position_indices.size(0) < query_len:
189
  position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
190
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
191
- position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
192
- position_indices = self.config.position_bucket_size - 1 + position_indices
193
- self.register_buffer("position_indices", position_indices.to(hidden_states.device), persistent=True)
194
 
195
  hidden_states = self.pre_layer_norm(hidden_states)
196
 
197
  query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D]
198
  value = self.in_proj_v(hidden_states) # shape: [T, B, D]
199
 
200
- pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
201
- pos = F.embedding(self.position_indices[:query_len, :key_len], pos) # shape: [T, T, 2D]
202
- pos = pos.view(query_len, key_len, self.num_heads, 2*self.head_size)
203
- query_pos, key_pos = pos.chunk(2, dim=3)
204
-
205
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
206
  key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
207
  value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
208
 
209
  attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
210
 
 
 
211
  query = query.view(batch_size, self.num_heads, query_len, self.head_size)
212
  key = key.view(batch_size, self.num_heads, query_len, self.head_size)
 
 
 
 
 
 
 
 
213
  attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
214
- attention_scores.add_(torch.einsum("bhqd,qkhd->bhqk", query, key_pos * self.scale))
215
- attention_scores.add_(torch.einsum("bhkd,qkhd->bhqk", key * self.scale, query_pos))
216
 
217
  return attention_scores, value
218
 
@@ -332,12 +335,16 @@ class NorbertModel(NorbertPreTrainedModel):
332
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
333
 
334
  if not return_dict:
335
- return sequence_output, contextualized_embeddings, attention_probs
 
 
 
 
336
 
337
  return BaseModelOutput(
338
  last_hidden_state=sequence_output,
339
- hidden_states=contextualized_embeddings,
340
- attentions=attention_probs
341
  )
342
 
343
 
@@ -375,14 +382,18 @@ class NorbertForMaskedLM(NorbertModel):
375
  masked_lm_loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())
376
 
377
  if not return_dict:
378
- output = (subword_prediction, contextualized_embeddings, attention_probs)
 
 
 
 
379
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
380
 
381
  return MaskedLMOutput(
382
  loss=masked_lm_loss,
383
  logits=subword_prediction,
384
- hidden_states=contextualized_embeddings,
385
- attentions=attention_probs
386
  )
387
 
388
 
@@ -465,14 +476,18 @@ class NorbertForSequenceClassification(NorbertModel):
465
  loss = loss_fct(logits, labels)
466
 
467
  if not return_dict:
468
- output = (logits, contextualized_embeddings, attention_probs)
 
 
 
 
469
  return ((loss,) + output) if loss is not None else output
470
 
471
  return SequenceClassifierOutput(
472
  loss=loss,
473
  logits=logits,
474
- hidden_states=contextualized_embeddings,
475
- attentions=attention_probs
476
  )
477
 
478
 
@@ -508,14 +523,18 @@ class NorbertForTokenClassification(NorbertModel):
508
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
509
 
510
  if not return_dict:
511
- output = (logits, contextualized_embeddings, attention_probs)
 
 
 
 
512
  return ((loss,) + output) if loss is not None else output
513
 
514
  return TokenClassifierOutput(
515
  loss=loss,
516
  logits=logits,
517
- hidden_states=contextualized_embeddings,
518
- attentions=attention_probs
519
  )
520
 
521
 
@@ -569,15 +588,20 @@ class NorbertForQuestionAnswering(NorbertModel):
569
  total_loss = (start_loss + end_loss) / 2
570
 
571
  if not return_dict:
572
- output = start_logits, end_logits, contextualized_embeddings, attention_probs
 
 
 
 
 
573
  return ((total_loss,) + output) if total_loss is not None else output
574
 
575
  return QuestionAnsweringModelOutput(
576
  loss=total_loss,
577
  start_logits=start_logits,
578
  end_logits=end_logits,
579
- hidden_states=contextualized_embeddings,
580
- attentions=attention_probs,
581
  )
582
 
583
 
@@ -598,9 +622,9 @@ class NorbertForMultipleChoice(NorbertModel):
598
  token_type_ids: Optional[torch.Tensor] = None,
599
  position_ids: Optional[torch.Tensor] = None,
600
  labels: Optional[torch.Tensor] = None,
601
- return_dict: Optional[bool] = None,
602
- start_positions: Optional[torch.Tensor] = None,
603
- end_positions: Optional[torch.Tensor] = None
604
  ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
605
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
606
  num_choices = input_ids.shape[1]
@@ -618,12 +642,16 @@ class NorbertForMultipleChoice(NorbertModel):
618
  loss = loss_fct(reshaped_logits, labels)
619
 
620
  if not return_dict:
621
- output = (reshaped_logits, contextualized_embeddings, attention_probs)
 
 
 
 
622
  return ((loss,) + output) if loss is not None else output
623
 
624
  return MultipleChoiceModelOutput(
625
  loss=loss,
626
  logits=reshaped_logits,
627
- hidden_states=contextualized_embeddings,
628
- attentions=attention_probs,
629
  )
 
 
 
1
  import math
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
 
7
  from torch.utils import checkpoint
8
 
9
  from configuration_norbert import NorbertConfig
 
17
  TokenClassifierOutput,
18
  BaseModelOutput
19
  )
20
+ from transformers.pytorch_utils import softmax_backward_data
21
 
22
 
23
  class Encoder(nn.Module):
 
128
  @staticmethod
129
  def backward(self, grad_output):
130
  output, = self.saved_tensors
131
+ input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
132
+ return input_grad, None, None
133
 
134
 
135
  class Attention(nn.Module):
 
186
  if self.position_indices.size(0) < query_len:
187
  position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
188
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
189
+ position_indices = self.make_log_bucket_position(position_indices, self.position_bucket_size, 512)
190
+ position_indices = self.position_bucket_size - 1 + position_indices
191
+ self.position_indices = position_indices.to(hidden_states.device)
192
 
193
  hidden_states = self.pre_layer_norm(hidden_states)
194
 
195
  query, key = self.in_proj_qk(hidden_states).chunk(2, dim=2) # shape: [T, B, D]
196
  value = self.in_proj_v(hidden_states) # shape: [T, B, D]
197
 
 
 
 
 
 
198
  query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
199
  key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
200
  value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
201
 
202
  attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
203
 
204
+ pos = self.in_proj_qk(self.dropout(relative_embedding)) # shape: [2T-1, 2D]
205
+ query_pos, key_pos = pos.view(-1, self.num_heads, 2*self.head_size).chunk(2, dim=2)
206
  query = query.view(batch_size, self.num_heads, query_len, self.head_size)
207
  key = key.view(batch_size, self.num_heads, query_len, self.head_size)
208
+
209
+ attention_c_p = torch.einsum("bhqd,khd->bhqk", query, key_pos.squeeze(1) * self.scale)
210
+ attention_p_c = torch.einsum("bhkd,qhd->bhqk", key * self.scale, query_pos.squeeze(1))
211
+
212
+ position_indices = self.position_indices[:query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
213
+ attention_c_p = attention_c_p.gather(3, position_indices)
214
+ attention_p_c = attention_p_c.gather(2, position_indices)
215
+
216
  attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
217
+ attention_scores.add_(attention_c_p)
218
+ attention_scores.add_(attention_p_c)
219
 
220
  return attention_scores, value
221
 
 
335
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
336
 
337
  if not return_dict:
338
+ return (
339
+ sequence_output,
340
+ *([contextualized_embeddings] if output_hidden_states else []),
341
+ *([attention_probs] if output_attentions else [])
342
+ )
343
 
344
  return BaseModelOutput(
345
  last_hidden_state=sequence_output,
346
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
347
+ attentions=attention_probs if output_attentions else None
348
  )
349
 
350
 
 
382
  masked_lm_loss = F.cross_entropy(subword_prediction.flatten(0, 1), labels.flatten())
383
 
384
  if not return_dict:
385
+ output = (
386
+ subword_prediction,
387
+ *([contextualized_embeddings] if output_hidden_states else []),
388
+ *([attention_probs] if output_attentions else [])
389
+ )
390
  return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
391
 
392
  return MaskedLMOutput(
393
  loss=masked_lm_loss,
394
  logits=subword_prediction,
395
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
396
+ attentions=attention_probs if output_attentions else None
397
  )
398
 
399
 
 
476
  loss = loss_fct(logits, labels)
477
 
478
  if not return_dict:
479
+ output = (
480
+ logits,
481
+ *([contextualized_embeddings] if output_hidden_states else []),
482
+ *([attention_probs] if output_attentions else [])
483
+ )
484
  return ((loss,) + output) if loss is not None else output
485
 
486
  return SequenceClassifierOutput(
487
  loss=loss,
488
  logits=logits,
489
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
490
+ attentions=attention_probs if output_attentions else None
491
  )
492
 
493
 
 
523
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
524
 
525
  if not return_dict:
526
+ output = (
527
+ logits,
528
+ *([contextualized_embeddings] if output_hidden_states else []),
529
+ *([attention_probs] if output_attentions else [])
530
+ )
531
  return ((loss,) + output) if loss is not None else output
532
 
533
  return TokenClassifierOutput(
534
  loss=loss,
535
  logits=logits,
536
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
537
+ attentions=attention_probs if output_attentions else None
538
  )
539
 
540
 
 
588
  total_loss = (start_loss + end_loss) / 2
589
 
590
  if not return_dict:
591
+ output = (
592
+ start_logits,
593
+ end_logits,
594
+ *([contextualized_embeddings] if output_hidden_states else []),
595
+ *([attention_probs] if output_attentions else [])
596
+ )
597
  return ((total_loss,) + output) if total_loss is not None else output
598
 
599
  return QuestionAnsweringModelOutput(
600
  loss=total_loss,
601
  start_logits=start_logits,
602
  end_logits=end_logits,
603
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
604
+ attentions=attention_probs if output_attentions else None
605
  )
606
 
607
 
 
622
  token_type_ids: Optional[torch.Tensor] = None,
623
  position_ids: Optional[torch.Tensor] = None,
624
  labels: Optional[torch.Tensor] = None,
625
+ output_attentions: Optional[bool] = None,
626
+ output_hidden_states: Optional[bool] = None,
627
+ return_dict: Optional[bool] = None
628
  ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
629
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
630
  num_choices = input_ids.shape[1]
 
642
  loss = loss_fct(reshaped_logits, labels)
643
 
644
  if not return_dict:
645
+ output = (
646
+ reshaped_logits,
647
+ *([contextualized_embeddings] if output_hidden_states else []),
648
+ *([attention_probs] if output_attentions else [])
649
+ )
650
  return ((loss,) + output) if loss is not None else output
651
 
652
  return MultipleChoiceModelOutput(
653
  loss=loss,
654
  logits=reshaped_logits,
655
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
656
+ attentions=attention_probs if output_attentions else None
657
  )