ccdv committed
Commit 2bc6c72
1 Parent(s): 203cea7

fix version 4.23

Files changed (1)
  1. modeling_lsg_bart.py +18 -91
modeling_lsg_bart.py CHANGED
@@ -57,7 +57,8 @@ class LSGBartConfig(BartConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                    setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -73,7 +74,7 @@ class LSGBartConfig(BartConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -81,6 +82,16 @@ class LSGBartConfig(BartConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
 
 class BaseSelfAttention(nn.Module):
 
@@ -557,8 +568,6 @@ class LSGBartEncoderAttention(BaseSelfAttention):
                 attention_mask=attention_mask
                 )
 
-            if head_mask is not None:
-                context_layer = context_layer * head_mask[:, :, :1, :1]
             return self.reshape_output(context_layer)
 
         # Split input into global tokens and other tokens
@@ -606,8 +615,6 @@ class LSGBartEncoderAttention(BaseSelfAttention):
 
         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
        context_layer = self.reshape_output(context_layer)
 
         return context_layer
@@ -630,35 +637,14 @@ class LSGBartEncoderLayer(BartEncoderLayer):
             dropout=config.attention_dropout,
             )
 
-
-class LSGBartDecoderLayer(BartDecoderLayer):
-
-    def __init__(self, config):
-
-        super().__init__(config)
 
-
-class LSGBartClassificationHead(BartClassificationHead):
-    """Head for sentence-level classification tasks."""
-
-    def __init__(
-        self,
-        input_dim,
-        inner_dim,
-        num_classes,
-        pooler_dropout,
-        ):
-
-        super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
-
-
 class LSGBartPretrainedModel(BartPretrainedModel):
 
     config_class = LSGBartConfig
 
     def _set_gradient_checkpointing(self, module, value=False):
 
-        if isinstance(module, (BartDecoder, BartEncoder, LSGBartDecoder, LSGBartEncoder)):
+        if isinstance(module, (BartDecoder, BartEncoder, LSGBartEncoder)):
             module.gradient_checkpointing = value
 
 
@@ -818,7 +804,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(inputs_embeds)
         hidden_states = inputs_embeds + embed_pos
 
         # Add global tokens
@@ -889,43 +875,6 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         )
 
 
-class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
-    Args:
-        config: BartConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config, embed_tokens=None):
-
-        LSGBartPretrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = BartLearnedPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-        )
-        self.layers = nn.ModuleList([LSGBartDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layernorm_embedding = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGBartModel(LSGBartPretrainedModel, BartModel):
 
     def __init__(self, config):
@@ -939,7 +888,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         self.num_global_tokens = config.num_global_tokens
 
         self.encoder = LSGBartEncoder(config, self.shared)
-        self.decoder = LSGBartDecoder(config, self.shared)
+        self.decoder = BartDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1052,7 +1001,7 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceCl
 
         LSGBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGBartModel(config)
-        self.classification_head = LSGBartClassificationHead(
+        self.classification_head = BartClassificationHead(
             config.d_model,
             config.d_model,
             config.num_labels,
@@ -1077,34 +1026,12 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnsweri
         self.model._init_weights(self.qa_outputs)
 
 
-class LSGBartDecoderWrapper(LSGBartPretrainedModel):
-    """
-    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
-    used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
-    """
-
-    def __init__(self, config: LSGBartConfig):
-        super().__init__(config)
-        self.decoder = LSGBartDecoder(config)
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
-
-
 class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
 
     def __init__(self, config: LSGBartConfig):
 
-        config = copy.deepcopy(config)
-        config.is_decoder = True
-        config.is_encoder_decoder = False
         LSGBartPretrainedModel.__init__(self, config)
-        self.model = LSGBartDecoderWrapper(config)
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
+        BartForCausalLM.__init__(self, config)
 
 
 def str_to_class(classname):
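For reference, a minimal usage sketch under transformers >= 4.23, the version this commit targets. It assumes an LSG-BART checkpoint on the Hub that ships this custom modeling_lsg_bart.py; the repository id below is only an illustrative placeholder, and trust_remote_code=True is what makes transformers load the custom LSG classes patched here.

# Minimal sketch, not part of the commit: assumes transformers >= 4.23 and an
# LSG-BART checkpoint on the Hub (replace the placeholder repo id as needed).
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo_id = "ccdv/lsg-bart-base-4096"  # placeholder repository id

# trust_remote_code=True downloads and runs modeling_lsg_bart.py from the repo
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)

text = "Long document to summarize. " * 200
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096)
summary_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))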