ccdv commited on
Commit
0390bcc
1 Parent(s): 24969bc

gradient checkpoint + cleanup

Browse files
Files changed (1) hide show
  1. modeling_lsg_bart.py +16 -453
modeling_lsg_bart.py CHANGED
@@ -81,51 +81,6 @@ class LSGBartConfig(BartConfig):
81
  assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
82
 
83
 
84
- def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
85
- """
86
- Shift input ids one token to the right.
87
- """
88
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
89
- shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
90
- shifted_input_ids[:, 0] = decoder_start_token_id
91
-
92
- if pad_token_id is None:
93
- raise ValueError("self.model.config.pad_token_id has to be defined.")
94
- # replace possible -100 values in labels by `pad_token_id`
95
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
96
-
97
- return shifted_input_ids
98
-
99
-
100
- def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
101
- """
102
- Make causal mask used for bi-directional self-attention.
103
- """
104
- bsz, tgt_len = input_ids_shape
105
- mask = torch.full((tgt_len, tgt_len), float("-inf"))
106
- mask_cond = torch.arange(mask.size(-1))
107
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
108
- mask = mask.to(dtype)
109
-
110
- if past_key_values_length > 0:
111
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
112
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
113
-
114
-
115
- def _expand_mask(mask, dtype, tgt_len=None):
116
- """
117
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
118
- """
119
- bsz, src_len = mask.size()
120
- tgt_len = tgt_len if tgt_len is not None else src_len
121
-
122
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
123
-
124
- inverted_mask = 1.0 - expanded_mask
125
-
126
- return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
127
-
128
-
129
  class BaseSelfAttention(nn.Module):
130
 
131
  def __init__(
@@ -663,364 +618,27 @@ class LSGBartEncoderAttention(BaseSelfAttention):
663
  return x.reshape(n, h, -1, chunk_size, d)
664
 
665
 
666
- class LSGBartDecoderAttention(nn.Module):
667
-
668
- """Multi-headed attention from 'Attention Is All You Need' paper"""
669
-
670
- def __init__(
671
- self,
672
- embed_dim,
673
- num_heads,
674
- dropout=0.0,
675
- is_decoder=False,
676
- bias=True,
677
- ):
678
-
679
- super().__init__()
680
- self.embed_dim = embed_dim
681
- self.num_heads = num_heads
682
- self.dropout = dropout
683
- self.head_dim = embed_dim // num_heads
684
-
685
- if (self.head_dim * num_heads) != self.embed_dim:
686
- raise ValueError(
687
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
688
- f" and `num_heads`: {num_heads})."
689
- )
690
- self.scaling = self.head_dim ** -0.5
691
- self.is_decoder = is_decoder
692
-
693
- self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
694
- self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
695
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
696
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
697
-
698
- def _shape(self, tensor, seq_len, bsz):
699
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
700
-
701
- def forward(
702
- self,
703
- hidden_states,
704
- key_value_states=None,
705
- past_key_value=None,
706
- attention_mask=None,
707
- layer_head_mask=None,
708
- output_attentions=False,
709
- ):
710
-
711
- # if key_value_states are provided this layer is used as a cross-attention layer
712
- # for the decoder
713
- is_cross_attention = key_value_states is not None
714
-
715
- bsz, tgt_len, _ = hidden_states.size()
716
-
717
- # get query proj
718
- query_states = self.q_proj(hidden_states) * self.scaling
719
- # get key, value proj
720
- if is_cross_attention and past_key_value is not None:
721
- # reuse k,v, cross_attentions
722
- key_states = past_key_value[0]
723
- value_states = past_key_value[1]
724
- elif is_cross_attention:
725
- # cross_attentions
726
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
727
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
728
- elif past_key_value is not None:
729
- # reuse k, v, self_attention
730
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
731
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
732
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
733
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
734
- else:
735
- # self_attention
736
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
737
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
738
-
739
- if self.is_decoder:
740
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
741
- # Further calls to cross_attention layer can then reuse all cross-attention
742
- # key/value_states (first "if" case)
743
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
744
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
745
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
746
- # if encoder bi-directional self-attention `past_key_value` is always `None`
747
- past_key_value = (key_states, value_states)
748
-
749
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
750
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
751
- key_states = key_states.view(*proj_shape)
752
- value_states = value_states.view(*proj_shape)
753
-
754
- src_len = key_states.size(1)
755
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
756
-
757
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
758
- raise ValueError(
759
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
760
- )
761
-
762
- if attention_mask is not None:
763
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
764
- raise ValueError(
765
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
766
- )
767
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
768
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
769
-
770
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
771
-
772
- if layer_head_mask is not None:
773
- if layer_head_mask.size() != (self.num_heads,):
774
- raise ValueError(
775
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
776
- )
777
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
778
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
779
-
780
- if output_attentions:
781
- # this operation is a bit awkward, but it's required to
782
- # make sure that attn_weights keeps its gradient.
783
- # In order to do so, attn_weights have to be reshaped
784
- # twice and have to be reused in the following
785
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
786
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
787
- else:
788
- attn_weights_reshaped = None
789
-
790
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
791
-
792
- attn_output = torch.bmm(attn_probs, value_states)
793
-
794
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
795
- raise ValueError(
796
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
797
- )
798
-
799
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
800
- attn_output = attn_output.transpose(1, 2)
801
-
802
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
803
- # partitioned aross GPUs when using tensor-parallelism.
804
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
805
-
806
- attn_output = self.out_proj(attn_output)
807
-
808
- return attn_output, attn_weights_reshaped, past_key_value
809
-
810
-
811
- class LSGBartLearnedPositionalEmbedding(nn.Embedding):
812
- """
813
- This module learns positional embeddings up to a fixed maximum size.
814
- """
815
-
816
- def __init__(self, num_embeddings, embedding_dim):
817
- # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
818
- # and adjust num_embeddings appropriately. Other models don't have this hack
819
- self.offset = 2
820
- super().__init__(num_embeddings + self.offset, embedding_dim)
821
-
822
- def forward(self, input_ids_shape, past_key_values_length=0):
823
-
824
- """`input_ids_shape` is expected to be [bsz x seqlen]."""
825
- bsz, seq_len = input_ids_shape[:2]
826
- positions = torch.arange(
827
- past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
828
- )
829
- return super().forward(positions + self.offset)
830
-
831
-
832
- class LSGBartEncoderLayer(nn.Module):
833
 
834
  def __init__(self, config):
835
 
836
- super().__init__()
837
- self.embed_dim = config.d_model
838
  self.self_attn = LSGBartEncoderAttention(
839
  config=config,
840
  embed_dim=self.embed_dim,
841
  num_heads=config.encoder_attention_heads,
842
  dropout=config.attention_dropout,
843
  )
844
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
845
- self.dropout = config.dropout
846
- self.activation_fn = ACT2FN[config.activation_function]
847
- self.activation_dropout = config.activation_dropout
848
- self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
849
- self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
850
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
851
-
852
- def forward(
853
- self,
854
- hidden_states,
855
- attention_mask,
856
- layer_head_mask,
857
- output_attentions=False,
858
- ):
859
- """
860
- Args:
861
- hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
862
- attention_mask (:obj:`torch.FloatTensor`): attention mask of size
863
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
864
- layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
865
- `(encoder_attention_heads,)`.
866
- output_attentions (:obj:`bool`, `optional`):
867
- Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
868
- returned tensors for more detail.
869
- """
870
- residual = hidden_states
871
- hidden_states, attn_weights, _ = self.self_attn(
872
- hidden_states=hidden_states,
873
- attention_mask=attention_mask,
874
- layer_head_mask=layer_head_mask,
875
- output_attentions=output_attentions,
876
- )
877
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
878
- hidden_states = residual + hidden_states
879
- hidden_states = self.self_attn_layer_norm(hidden_states)
880
-
881
- residual = hidden_states
882
- hidden_states = self.activation_fn(self.fc1(hidden_states))
883
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
884
- hidden_states = self.fc2(hidden_states)
885
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
886
- hidden_states = residual + hidden_states
887
- hidden_states = self.final_layer_norm(hidden_states)
888
-
889
- if hidden_states.dtype == torch.float16 and (
890
- torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
891
- ):
892
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
893
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
894
-
895
- outputs = (hidden_states,)
896
-
897
- if output_attentions:
898
- outputs += (attn_weights,)
899
-
900
- return outputs
901
 
902
 
903
- class LSGBartDecoderLayer(nn.Module):
904
 
905
  def __init__(self, config):
906
 
907
- super().__init__()
908
- self.embed_dim = config.d_model
909
-
910
- self.self_attn = LSGBartDecoderAttention(
911
- embed_dim=self.embed_dim,
912
- num_heads=config.decoder_attention_heads,
913
- dropout=config.attention_dropout,
914
- is_decoder=True,
915
- )
916
- self.dropout = config.dropout
917
- self.activation_fn = ACT2FN[config.activation_function]
918
- self.activation_dropout = config.activation_dropout
919
-
920
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
921
- self.encoder_attn = LSGBartDecoderAttention(
922
- self.embed_dim,
923
- config.decoder_attention_heads,
924
- dropout=config.attention_dropout,
925
- is_decoder=True,
926
- )
927
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
928
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
929
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
930
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
931
-
932
- def forward(
933
- self,
934
- hidden_states,
935
- attention_mask=None,
936
- encoder_hidden_states=None,
937
- encoder_attention_mask=None,
938
- layer_head_mask=None,
939
- cross_attn_layer_head_mask=None,
940
- past_key_value=None,
941
- output_attentions=False,
942
- use_cache=True,
943
- ):
944
- """
945
- Args:
946
- hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
947
- attention_mask (:obj:`torch.FloatTensor`): attention mask of size
948
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
949
- encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
950
- encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
951
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
952
- layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
953
- `(encoder_attention_heads,)`.
954
- cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
955
- size `(decoder_attention_heads,)`.
956
- past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
957
- output_attentions (:obj:`bool`, `optional`):
958
- Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
959
- returned tensors for more detail.
960
- """
961
- residual = hidden_states
962
-
963
- # Self Attention
964
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
965
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
966
- # add present self-attn cache to positions 1,2 of present_key_value tuple
967
-
968
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
969
- hidden_states=hidden_states,
970
- past_key_value=self_attn_past_key_value,
971
- attention_mask=attention_mask,
972
- layer_head_mask=layer_head_mask,
973
- output_attentions=output_attentions,
974
- )
975
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
976
- hidden_states = residual + hidden_states
977
- hidden_states = self.self_attn_layer_norm(hidden_states)
978
-
979
- # Cross-Attention Block
980
- cross_attn_present_key_value = None
981
- cross_attn_weights = None
982
- if encoder_hidden_states is not None:
983
- residual = hidden_states
984
-
985
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
986
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
987
-
988
- hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
989
- hidden_states=hidden_states,
990
- key_value_states=encoder_hidden_states,
991
- attention_mask=encoder_attention_mask,
992
- layer_head_mask=cross_attn_layer_head_mask,
993
- past_key_value=cross_attn_past_key_value,
994
- output_attentions=output_attentions,
995
- )
996
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
997
- hidden_states = residual + hidden_states
998
- hidden_states = self.encoder_attn_layer_norm(hidden_states)
999
-
1000
- # add cross-attn to positions 3,4 of present_key_value tuple
1001
- present_key_value = present_key_value + cross_attn_present_key_value
1002
-
1003
- # Fully Connected
1004
- residual = hidden_states
1005
- hidden_states = self.activation_fn(self.fc1(hidden_states))
1006
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
1007
- hidden_states = self.fc2(hidden_states)
1008
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1009
- hidden_states = residual + hidden_states
1010
- hidden_states = self.final_layer_norm(hidden_states)
1011
-
1012
- outputs = (hidden_states,)
1013
-
1014
- if output_attentions:
1015
- outputs += (self_attn_weights, cross_attn_weights)
1016
-
1017
- if use_cache:
1018
- outputs += (present_key_value,)
1019
-
1020
- return outputs
1021
-
1022
 
1023
- class LSGBartClassificationHead(nn.Module):
1024
  """Head for sentence-level classification tasks."""
1025
 
1026
  def __init__(
@@ -1031,55 +649,18 @@ class LSGBartClassificationHead(nn.Module):
1031
  pooler_dropout,
1032
  ):
1033
 
1034
- super().__init__()
1035
- self.dense = nn.Linear(input_dim, inner_dim)
1036
- self.dropout = nn.Dropout(p=pooler_dropout)
1037
- self.out_proj = nn.Linear(inner_dim, num_classes)
1038
-
1039
- def forward(self, hidden_states):
1040
-
1041
- hidden_states = self.dropout(hidden_states)
1042
- hidden_states = self.dense(hidden_states)
1043
- hidden_states = torch.tanh(hidden_states)
1044
- hidden_states = self.dropout(hidden_states)
1045
- hidden_states = self.out_proj(hidden_states)
1046
- return hidden_states
1047
 
1048
 
1049
- class LSGBartPretrainedModel(PreTrainedModel):
1050
 
1051
  config_class = LSGBartConfig
1052
- base_model_prefix = "model"
1053
- supports_gradient_checkpointing = True
1054
- _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
1055
-
1056
- def _init_weights(self, module):
1057
-
1058
- std = self.config.init_std
1059
- if isinstance(module, nn.Linear):
1060
- module.weight.data.normal_(mean=0.0, std=std)
1061
- if module.bias is not None:
1062
- module.bias.data.zero_()
1063
- elif isinstance(module, nn.Embedding):
1064
- module.weight.data.normal_(mean=0.0, std=std)
1065
- if module.padding_idx is not None:
1066
- module.weight.data[module.padding_idx].zero_()
1067
 
1068
  def _set_gradient_checkpointing(self, module, value=False):
1069
 
1070
- if isinstance(module, (LSGBartDecoder, LSGBartEncoder)):
1071
  module.gradient_checkpointing = value
1072
 
1073
- @property
1074
- def dummy_inputs(self):
1075
- pad_token = self.config.pad_token_id
1076
- input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
1077
- dummy_inputs = {
1078
- "attention_mask": input_ids.ne(pad_token),
1079
- "input_ids": input_ids,
1080
- }
1081
- return dummy_inputs
1082
-
1083
 
1084
  class PretrainedLSGBartModel(LSGBartPretrainedModel):
1085
 
@@ -1090,7 +671,7 @@ class PretrainedLSGBartModel(LSGBartPretrainedModel):
1090
  )
1091
 
1092
 
1093
- class LSGBartEncoder(LSGBartPretrainedModel):
1094
  """
1095
  Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
1096
  :class:`BartEncoderLayer`.
@@ -1115,7 +696,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
1115
  else:
1116
  self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
1117
 
1118
- self.embed_positions = LSGBartLearnedPositionalEmbedding(
1119
  config.max_position_embeddings,
1120
  embed_dim,
1121
  )
@@ -1140,12 +721,6 @@ class LSGBartEncoder(LSGBartPretrainedModel):
1140
  # Initialize weights and apply final processing
1141
  self.post_init()
1142
 
1143
- def get_input_embeddings(self):
1144
- return self.embed_tokens
1145
-
1146
- def set_input_embeddings(self, value):
1147
- self.embed_tokens = value
1148
-
1149
  def forward(self,
1150
  input_ids=None,
1151
  attention_mask=None,
@@ -1335,7 +910,7 @@ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
1335
  else:
1336
  self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
1337
 
1338
- self.embed_positions = LSGBartLearnedPositionalEmbedding(
1339
  config.max_position_embeddings,
1340
  config.d_model,
1341
  )
@@ -1348,36 +923,24 @@ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
1348
  self.post_init()
1349
 
1350
 
1351
- class LSGBartModel(LSGBartPretrainedModel):
1352
 
1353
  def __init__(self, config):
1354
 
1355
- super().__init__(config)
1356
 
1357
  padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1358
  self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
 
1359
  self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
1360
  self.num_global_tokens = config.num_global_tokens
 
1361
  self.encoder = LSGBartEncoder(config, self.shared)
1362
  self.decoder = LSGBartDecoder(config, self.shared)
1363
 
1364
  # Initialize weights and apply final processing
1365
  self.post_init()
1366
 
1367
- def get_input_embeddings(self):
1368
- return self.shared
1369
-
1370
- def set_input_embeddings(self, value):
1371
- self.shared = value
1372
- self.encoder.embed_tokens = self.shared
1373
- self.decoder.embed_tokens = self.shared
1374
-
1375
- def get_encoder(self):
1376
- return self.encoder
1377
-
1378
- def get_decoder(self):
1379
- return self.decoder
1380
-
1381
  def forward(
1382
  self,
1383
  input_ids=None,
81
  assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  class BaseSelfAttention(nn.Module):
85
 
86
  def __init__(
618
  return x.reshape(n, h, -1, chunk_size, d)
619
 
620
 
621
+ class LSGBartEncoderLayer(BartEncoderLayer):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
  def __init__(self, config):
624
 
625
+ super().__init__(config)
 
626
  self.self_attn = LSGBartEncoderAttention(
627
  config=config,
628
  embed_dim=self.embed_dim,
629
  num_heads=config.encoder_attention_heads,
630
  dropout=config.attention_dropout,
631
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632
 
633
 
634
+ class LSGBartDecoderLayer(BartDecoderLayer):
635
 
636
  def __init__(self, config):
637
 
638
+ super().__init__(config)
639
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640
 
641
+ class LSGBartClassificationHead(BartClassificationHead):
642
  """Head for sentence-level classification tasks."""
643
 
644
  def __init__(
649
  pooler_dropout,
650
  ):
651
 
652
+ super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
 
 
 
 
 
 
 
 
 
 
 
 
653
 
654
 
655
+ class LSGBartPretrainedModel(BartPretrainedModel):
656
 
657
  config_class = LSGBartConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
658
 
659
  def _set_gradient_checkpointing(self, module, value=False):
660
 
661
+ if isinstance(module, (BartDecoder, BartEncoder, LSGBartDecoder, LSGBartEncoder)):
662
  module.gradient_checkpointing = value
663
 
 
 
 
 
 
 
 
 
 
 
664
 
665
  class PretrainedLSGBartModel(LSGBartPretrainedModel):
666
 
671
  )
672
 
673
 
674
+ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
675
  """
676
  Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
677
  :class:`BartEncoderLayer`.
696
  else:
697
  self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
698
 
699
+ self.embed_positions = BartLearnedPositionalEmbedding(
700
  config.max_position_embeddings,
701
  embed_dim,
702
  )
721
  # Initialize weights and apply final processing
722
  self.post_init()
723
 
 
 
 
 
 
 
724
  def forward(self,
725
  input_ids=None,
726
  attention_mask=None,
910
  else:
911
  self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
912
 
913
+ self.embed_positions = BartLearnedPositionalEmbedding(
914
  config.max_position_embeddings,
915
  config.d_model,
916
  )
923
  self.post_init()
924
 
925
 
926
+ class LSGBartModel(LSGBartPretrainedModel, BartModel):
927
 
928
  def __init__(self, config):
929
 
930
+ LSGBartPretrainedModel.__init__(self, config)
931
 
932
  padding_idx, vocab_size = config.pad_token_id, config.vocab_size
933
  self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
934
+
935
  self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
936
  self.num_global_tokens = config.num_global_tokens
937
+
938
  self.encoder = LSGBartEncoder(config, self.shared)
939
  self.decoder = LSGBartDecoder(config, self.shared)
940
 
941
  # Initialize weights and apply final processing
942
  self.post_init()
943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
944
  def forward(
945
  self,
946
  input_ids=None,