ccdv committed
Commit 2a20a92
1 Parent(s): 01d5c90

small fix with torch.finfo

Files changed (1):
  1. modeling_lsg_pegasus.py +26 -65
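
The fix named in the message splits the additive attention-mask construction into two steps: clamp the pooled mask to {0, 1}, then scale it in place by torch.finfo(mask.dtype).min. A minimal standalone sketch of that pattern, with illustrative tensor shapes that are assumptions and not taken from the model code:

    import torch

    n, h, t = 2, 4, 8                                  # batch, heads, blocked length (illustrative)
    mask = torch.randint(0, 3, (n, 1, t, 1)).float()   # pooled key counts; 0 means "nothing to attend to"

    # Two-step form used after this commit: build the {0, 1} complement first,
    # then scale it in place by the most negative value of the mask's float dtype.
    mask = 1. - mask.clamp(0, 1)
    mask *= torch.finfo(mask.dtype).min

    scores = torch.zeros(n, h, t, t)
    scores = scores + mask.transpose(-1, -2)           # masked keys get ~ -inf before softmax
    probs = torch.softmax(scores, dim=-1)
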
modeling_lsg_pegasus.py CHANGED
@@ -13,7 +13,7 @@ AUTO_MAP = {
 
 class LSGPegasusConfig(PegasusConfig):
     """
-    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
+    This class overrides :class:`~transformers.PegasusConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -55,7 +55,8 @@ class LSGPegasusConfig(PegasusConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                    setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -63,7 +64,7 @@ class LSGPegasusConfig(PegasusConfig):
             logger.warning(
                 "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
             )
-
+
         if self.num_global_tokens < 1:
             logger.warning(
                 "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
@@ -71,13 +72,23 @@ class LSGPegasusConfig(PegasusConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
-
+
         if self.sparsity_factor > 0:
             assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
             assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
+
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
 
 
 class BaseSelfAttention(nn.Module):
@@ -422,7 +433,8 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -487,7 +499,8 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
@@ -556,8 +569,6 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
             attention_mask=attention_mask
         )
 
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         return self.reshape_output(context_layer)
 
     # Split input into global tokens and other tokens
@@ -605,8 +616,6 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
 
         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         context_layer = self.reshape_output(context_layer)
 
         return context_layer
@@ -665,21 +674,13 @@ class LSGPegasusEncoderLayer(PegasusEncoderLayer):
             dropout=config.attention_dropout,
         )
 
-
-# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus
-class LSGPegasusDecoderLayer(PegasusDecoderLayer):
-
-    def __init__(self, config: LSGPegasusConfig):
-
-        super().__init__(config)
 
-
 class LSGPegasusPreTrainedModel(PegasusPreTrainedModel):
 
     config_class = LSGPegasusConfig
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (PegasusDecoder, PegasusEncoder, LSGPegasusDecoder, LSGPegasusEncoder)):
+        if isinstance(module, (PegasusDecoder, PegasusEncoder, LSGPegasusEncoder)):
            module.gradient_checkpointing = value
 
 
@@ -922,44 +923,6 @@ class LSGPegasusEncoder(LSGPegasusPreTrainedModel, PegasusEncoder):
         )
 
 
-class LSGPegasusDecoder(LSGPegasusPreTrainedModel, PegasusDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`PegasusDecoderLayer`
-    Args:
-        config: PegasusConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config: LSGPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
-
-        LSGPegasusPreTrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = LSGPegasusSinusoidalPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-            self.padding_idx,
-        )
-        self.layers = nn.ModuleList([LSGPegasusDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layer_norm = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
 
     def __init__(self, config: LSGPegasusConfig):
@@ -971,7 +934,7 @@ class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
         self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
         self.num_global_tokens = config.num_global_tokens
         self.encoder = LSGPegasusEncoder(config, self.shared)
-        self.decoder = LSGPegasusDecoder(config, self.shared)
+        self.decoder = PegasusDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1078,6 +1041,7 @@ class LSGPegasusForConditionalGeneration(LSGPegasusPreTrainedModel, PegasusForCo
     ]
 
     def __init__(self, config: LSGPegasusConfig):
+
         LSGPegasusPreTrainedModel.__init__(self, config)
         self.model = LSGPegasusModel(config)
         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
@@ -1088,18 +1052,15 @@ class LSGPegasusForConditionalGeneration(LSGPegasusPreTrainedModel, PegasusForCo
 
 
 # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
-class LSGPegasusDecoderWrapper(LSGPegasusPreTrainedModel):
+class LSGPegasusDecoderWrapper(LSGPegasusPreTrainedModel, PegasusDecoderWrapper):
     """
     This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
     used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
     """
 
     def __init__(self, config):
-        super().__init__(config)
-        self.decoder = LSGPegasusDecoder(config)
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
+        LSGPegasusPreTrainedModel.__init__(self, config)
+        PegasusDecoderWrapper.__init__(self, config)
 
 
 class LSGPegasusForCausalLM(LSGPegasusPreTrainedModel, PegasusForCausalLM):
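
Since the modeling file registers its classes through AUTO_MAP (visible in the first hunk's context), the updated LSG classes are meant to be loaded through the transformers auto classes with remote code enabled. A hedged usage sketch; the checkpoint id below is a placeholder assumption, not something stated by this commit:

    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    checkpoint = "ccdv/lsg-pegasus-large-4096"  # placeholder repo id for illustration

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, trust_remote_code=True)

    inputs = tokenizer("A long document to summarize ...", return_tensors="pt")
    summary_ids = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))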