robinzixuan committed on
Commit
abe9faf
1 Parent(s): 75986f3

Update modeling_bert.py

Files changed (1)
  1. modeling_bert.py +3 -151
modeling_bert.py CHANGED
@@ -6,6 +6,7 @@
 # you may not use this file except in compliance with the License.

 # You may obtain a copy of the License at
+
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
@@ -411,156 +412,7 @@ class BertSelfAttention(nn.Module):
         return outputs


-class BertOutEffHop(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
-        super().__init__()
-        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-            raise ValueError(
-                f'''The hidden size ({
-                    config.hidden_size}) is not a multiple of the number of attention '''
-                f"heads ({config.num_attention_heads})"
-            )
-
-        self.num_attention_heads = config.num_attention_heads
-        self.attention_head_size = int(
-            config.hidden_size / config.num_attention_heads)
-        self.all_head_size = self.num_attention_heads * self.attention_head_size
-
-        self.query = nn.Linear(config.hidden_size, self.all_head_size)
-        self.key = nn.Linear(config.hidden_size, self.all_head_size)
-        self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = position_embedding_type or getattr(
-            config, "position_embedding_type", "absolute"
-        )
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            self.max_position_embeddings = config.max_position_embeddings
-            self.distance_embedding = nn.Embedding(
-                2 * config.max_position_embeddings - 1, self.attention_head_size)
-
-        self.is_decoder = config.is_decoder

-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[
-            :-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
-
-        # If this is instantiated as a cross-attention module, the keys
-        # and values come from an encoder; the attention mask needs to be
-        # such that the encoder's padding tokens are not attended to.
-        is_cross_attention = encoder_hidden_states is not None
-
-        if is_cross_attention and past_key_value is not None:
-            # reuse k,v, cross_attentions
-            key_layer = past_key_value[0]
-            value_layer = past_key_value[1]
-            attention_mask = encoder_attention_mask
-        elif is_cross_attention:
-            key_layer = self.transpose_for_scores(
-                self.key(encoder_hidden_states))
-            value_layer = self.transpose_for_scores(
-                self.value(encoder_hidden_states))
-            attention_mask = encoder_attention_mask
-        elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
-            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
-        else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
-            value_layer = self.transpose_for_scores(self.value(hidden_states))
-
-        query_layer = self.transpose_for_scores(mixed_query_layer)
-
-        use_cache = past_key_value is not None
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_layer, value_layer)
-
-        # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = torch.matmul(
-            query_layer, key_layer.transpose(-1, -2))
-
-        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
-            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
-            if use_cache:
-                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
-                    -1, 1
-                )
-            else:
-                position_ids_l = torch.arange(
-                    query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
-            position_ids_r = torch.arange(
-                key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
-            distance = position_ids_l - position_ids_r
-
-            positional_embedding = self.distance_embedding(
-                distance + self.max_position_embeddings - 1)
-            positional_embedding = positional_embedding.to(
-                dtype=query_layer.dtype)  # fp16 compatibility
-
-            if self.position_embedding_type == "relative_key":
-                relative_position_scores = torch.einsum(
-                    "bhld,lrd->bhlr", query_layer, positional_embedding)
-                attention_scores = attention_scores + relative_position_scores
-            elif self.position_embedding_type == "relative_key_query":
-                relative_position_scores_query = torch.einsum(
-                    "bhld,lrd->bhlr", query_layer, positional_embedding)
-                relative_position_scores_key = torch.einsum(
-                    "bhrd,lrd->bhlr", key_layer, positional_embedding)
-                attention_scores = attention_scores + \
-                    relative_position_scores_query + relative_position_scores_key
-
-        attention_scores = attention_scores / \
-            math.sqrt(self.attention_head_size)
-        if attention_mask is not None:
-            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
-
-        # Normalize the attention scores to probabilities.
-        attention_probs = softmax_1(attention_scores, dim=-1)
-        print(softmax_1)
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs = self.dropout(attention_probs)
-
-        # Mask heads if we want to
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        context_layer = torch.matmul(attention_probs, value_layer)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[
-            :-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-
-        outputs = (context_layer, attention_probs) if output_attentions else (
-            context_layer,)
-
-        if self.is_decoder:
-            outputs = outputs + (past_key_value,)
-        return outputs


 class BertSdpaSelfAttention(BertSelfAttention):
@@ -684,14 +536,14 @@ class BertSelfOutput(nn.Module):
 BERT_SELF_ATTENTION_CLASSES = {
     "eager": BertSelfAttention,
     "sdpa": BertSdpaSelfAttention,
-    "OutEffHop": BertOutEffHop,
+
 }


 class BertAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
+        self.self = BertSelfAttention(
            config, position_embedding_type=position_embedding_type
        )
        self.output = BertSelfOutput(config)
 
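For context on what this commit removes: the deleted BertOutEffHop module appears to be a copy of BertSelfAttention whose main functional difference is normalizing attention scores with softmax_1 instead of the standard softmax (plus a stray debug print(softmax_1)). softmax_1 itself is not defined in this diff; the sketch below is a minimal assumed implementation, using the usual "softmax with an extra +1 in the denominator" formulation associated with outlier-efficient attention/Hopfield work, not necessarily the exact definition imported by this file.

```python
import torch

def softmax_1(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Assumed definition: exp(x_i) / (1 + sum_j exp(x_j)), i.e. softmax with an
    # implicit extra logit fixed at 0, so the weights may sum to less than 1.
    # Shift by a non-negative stabilizer to avoid overflow; the "+1" term
    # becomes exp(-m) after the shift, keeping the result mathematically equal.
    m = torch.clamp(x.max(dim=dim, keepdim=True).values, min=0)
    e = torch.exp(x - m)
    return e / (torch.exp(-m) + e.sum(dim=dim, keepdim=True))
```

Under this formulation a token's attention weights need not sum to one, which is the property the OutEffHop variant relies on; the eager BertSelfAttention path that remains after this commit uses the standard softmax.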