gradient checkpoint + cleanup
Browse files
- modeling_lsg_bart.py +16 -453
modeling_lsg_bart.py
CHANGED
@@ -81,51 +81,6 @@ class LSGBartConfig(BartConfig):
|
|
81 |
assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
|
82 |
|
83 |
|
84 |
-
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
|
85 |
-
"""
|
86 |
-
Shift input ids one token to the right.
|
87 |
-
"""
|
88 |
-
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
89 |
-
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
90 |
-
shifted_input_ids[:, 0] = decoder_start_token_id
|
91 |
-
|
92 |
-
if pad_token_id is None:
|
93 |
-
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
94 |
-
# replace possible -100 values in labels by `pad_token_id`
|
95 |
-
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
96 |
-
|
97 |
-
return shifted_input_ids
|
98 |
-
|
99 |
-
|
100 |
-
def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
|
101 |
-
"""
|
102 |
-
Make causal mask used for uni-directional (autoregressive) self-attention.
|
103 |
-
"""
|
104 |
-
bsz, tgt_len = input_ids_shape
|
105 |
-
mask = torch.full((tgt_len, tgt_len), float("-inf"))
|
106 |
-
mask_cond = torch.arange(mask.size(-1))
|
107 |
-
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
108 |
-
mask = mask.to(dtype)
|
109 |
-
|
110 |
-
if past_key_values_length > 0:
|
111 |
-
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
|
112 |
-
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
113 |
-
|
114 |
-
|
115 |
-
def _expand_mask(mask, dtype, tgt_len=None):
|
116 |
-
"""
|
117 |
-
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
118 |
-
"""
|
119 |
-
bsz, src_len = mask.size()
|
120 |
-
tgt_len = tgt_len if tgt_len is not None else src_len
|
121 |
-
|
122 |
-
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
123 |
-
|
124 |
-
inverted_mask = 1.0 - expanded_mask
|
125 |
-
|
126 |
-
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
|
127 |
-
|
128 |
-
|
129 |
class BaseSelfAttention(nn.Module):
|
130 |
|
131 |
def __init__(
|
@@ -663,364 +618,27 @@ class LSGBartEncoderAttention(BaseSelfAttention):
|
|
663 |
return x.reshape(n, h, -1, chunk_size, d)
|
664 |
|
665 |
|
666 |
-
class
|
667 |
-
|
668 |
-
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
669 |
-
|
670 |
-
def __init__(
|
671 |
-
self,
|
672 |
-
embed_dim,
|
673 |
-
num_heads,
|
674 |
-
dropout=0.0,
|
675 |
-
is_decoder=False,
|
676 |
-
bias=True,
|
677 |
-
):
|
678 |
-
|
679 |
-
super().__init__()
|
680 |
-
self.embed_dim = embed_dim
|
681 |
-
self.num_heads = num_heads
|
682 |
-
self.dropout = dropout
|
683 |
-
self.head_dim = embed_dim // num_heads
|
684 |
-
|
685 |
-
if (self.head_dim * num_heads) != self.embed_dim:
|
686 |
-
raise ValueError(
|
687 |
-
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
688 |
-
f" and `num_heads`: {num_heads})."
|
689 |
-
)
|
690 |
-
self.scaling = self.head_dim ** -0.5
|
691 |
-
self.is_decoder = is_decoder
|
692 |
-
|
693 |
-
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
694 |
-
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
695 |
-
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
696 |
-
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
697 |
-
|
698 |
-
def _shape(self, tensor, seq_len, bsz):
|
699 |
-
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
700 |
-
|
701 |
-
def forward(
|
702 |
-
self,
|
703 |
-
hidden_states,
|
704 |
-
key_value_states=None,
|
705 |
-
past_key_value=None,
|
706 |
-
attention_mask=None,
|
707 |
-
layer_head_mask=None,
|
708 |
-
output_attentions=False,
|
709 |
-
):
|
710 |
-
|
711 |
-
# if key_value_states are provided this layer is used as a cross-attention layer
|
712 |
-
# for the decoder
|
713 |
-
is_cross_attention = key_value_states is not None
|
714 |
-
|
715 |
-
bsz, tgt_len, _ = hidden_states.size()
|
716 |
-
|
717 |
-
# get query proj
|
718 |
-
query_states = self.q_proj(hidden_states) * self.scaling
|
719 |
-
# get key, value proj
|
720 |
-
if is_cross_attention and past_key_value is not None:
|
721 |
-
# reuse k,v, cross_attentions
|
722 |
-
key_states = past_key_value[0]
|
723 |
-
value_states = past_key_value[1]
|
724 |
-
elif is_cross_attention:
|
725 |
-
# cross_attentions
|
726 |
-
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
727 |
-
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
728 |
-
elif past_key_value is not None:
|
729 |
-
# reuse k, v, self_attention
|
730 |
-
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
731 |
-
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
732 |
-
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
733 |
-
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
734 |
-
else:
|
735 |
-
# self_attention
|
736 |
-
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
737 |
-
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
738 |
-
|
739 |
-
if self.is_decoder:
|
740 |
-
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
741 |
-
# Further calls to cross_attention layer can then reuse all cross-attention
|
742 |
-
# key/value_states (first "if" case)
|
743 |
-
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
744 |
-
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
745 |
-
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
746 |
-
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
747 |
-
past_key_value = (key_states, value_states)
|
748 |
-
|
749 |
-
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
750 |
-
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
751 |
-
key_states = key_states.view(*proj_shape)
|
752 |
-
value_states = value_states.view(*proj_shape)
|
753 |
-
|
754 |
-
src_len = key_states.size(1)
|
755 |
-
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
756 |
-
|
757 |
-
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
758 |
-
raise ValueError(
|
759 |
-
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
|
760 |
-
)
|
761 |
-
|
762 |
-
if attention_mask is not None:
|
763 |
-
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
764 |
-
raise ValueError(
|
765 |
-
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
766 |
-
)
|
767 |
-
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
768 |
-
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
769 |
-
|
770 |
-
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
771 |
-
|
772 |
-
if layer_head_mask is not None:
|
773 |
-
if layer_head_mask.size() != (self.num_heads,):
|
774 |
-
raise ValueError(
|
775 |
-
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
|
776 |
-
)
|
777 |
-
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
778 |
-
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
779 |
-
|
780 |
-
if output_attentions:
|
781 |
-
# this operation is a bit awkward, but it's required to
|
782 |
-
# make sure that attn_weights keeps its gradient.
|
783 |
-
# In order to do so, attn_weights have to be reshaped
|
784 |
-
# twice and have to be reused in the following
|
785 |
-
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
786 |
-
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
787 |
-
else:
|
788 |
-
attn_weights_reshaped = None
|
789 |
-
|
790 |
-
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
791 |
-
|
792 |
-
attn_output = torch.bmm(attn_probs, value_states)
|
793 |
-
|
794 |
-
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
795 |
-
raise ValueError(
|
796 |
-
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
|
797 |
-
)
|
798 |
-
|
799 |
-
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
800 |
-
attn_output = attn_output.transpose(1, 2)
|
801 |
-
|
802 |
-
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
803 |
-
# partitioned across GPUs when using tensor-parallelism.
|
804 |
-
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
805 |
-
|
806 |
-
attn_output = self.out_proj(attn_output)
|
807 |
-
|
808 |
-
return attn_output, attn_weights_reshaped, past_key_value
|
809 |
-
|
810 |
-
|
811 |
-
class LSGBartLearnedPositionalEmbedding(nn.Embedding):
|
812 |
-
"""
|
813 |
-
This module learns positional embeddings up to a fixed maximum size.
|
814 |
-
"""
|
815 |
-
|
816 |
-
def __init__(self, num_embeddings, embedding_dim):
|
817 |
-
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
818 |
-
# and adjust num_embeddings appropriately. Other models don't have this hack
|
819 |
-
self.offset = 2
|
820 |
-
super().__init__(num_embeddings + self.offset, embedding_dim)
|
821 |
-
|
822 |
-
def forward(self, input_ids_shape, past_key_values_length=0):
|
823 |
-
|
824 |
-
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
|
825 |
-
bsz, seq_len = input_ids_shape[:2]
|
826 |
-
positions = torch.arange(
|
827 |
-
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
|
828 |
-
)
|
829 |
-
return super().forward(positions + self.offset)
|
830 |
-
|
831 |
-
|
832 |
-
class LSGBartEncoderLayer(nn.Module):
|
833 |
|
834 |
def __init__(self, config):
|
835 |
|
836 |
-
super().__init__()
|
837 |
-
self.embed_dim = config.d_model
|
838 |
self.self_attn = LSGBartEncoderAttention(
|
839 |
config=config,
|
840 |
embed_dim=self.embed_dim,
|
841 |
num_heads=config.encoder_attention_heads,
|
842 |
dropout=config.attention_dropout,
|
843 |
)
|
844 |
-
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
845 |
-
self.dropout = config.dropout
|
846 |
-
self.activation_fn = ACT2FN[config.activation_function]
|
847 |
-
self.activation_dropout = config.activation_dropout
|
848 |
-
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
849 |
-
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
850 |
-
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
851 |
-
|
852 |
-
def forward(
|
853 |
-
self,
|
854 |
-
hidden_states,
|
855 |
-
attention_mask,
|
856 |
-
layer_head_mask,
|
857 |
-
output_attentions=False,
|
858 |
-
):
|
859 |
-
"""
|
860 |
-
Args:
|
861 |
-
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
862 |
-
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
863 |
-
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
864 |
-
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
865 |
-
`(encoder_attention_heads,)`.
|
866 |
-
output_attentions (:obj:`bool`, `optional`):
|
867 |
-
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
868 |
-
returned tensors for more detail.
|
869 |
-
"""
|
870 |
-
residual = hidden_states
|
871 |
-
hidden_states, attn_weights, _ = self.self_attn(
|
872 |
-
hidden_states=hidden_states,
|
873 |
-
attention_mask=attention_mask,
|
874 |
-
layer_head_mask=layer_head_mask,
|
875 |
-
output_attentions=output_attentions,
|
876 |
-
)
|
877 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
878 |
-
hidden_states = residual + hidden_states
|
879 |
-
hidden_states = self.self_attn_layer_norm(hidden_states)
|
880 |
-
|
881 |
-
residual = hidden_states
|
882 |
-
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
883 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
884 |
-
hidden_states = self.fc2(hidden_states)
|
885 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
886 |
-
hidden_states = residual + hidden_states
|
887 |
-
hidden_states = self.final_layer_norm(hidden_states)
|
888 |
-
|
889 |
-
if hidden_states.dtype == torch.float16 and (
|
890 |
-
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
891 |
-
):
|
892 |
-
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
893 |
-
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
|
894 |
-
|
895 |
-
outputs = (hidden_states,)
|
896 |
-
|
897 |
-
if output_attentions:
|
898 |
-
outputs += (attn_weights,)
|
899 |
-
|
900 |
-
return outputs
|
901 |
|
902 |
|
903 |
-
class LSGBartDecoderLayer(
|
904 |
|
905 |
def __init__(self, config):
|
906 |
|
907 |
-
super().__init__()
|
908 |
-
|
909 |
-
|
910 |
-
self.self_attn = LSGBartDecoderAttention(
|
911 |
-
embed_dim=self.embed_dim,
|
912 |
-
num_heads=config.decoder_attention_heads,
|
913 |
-
dropout=config.attention_dropout,
|
914 |
-
is_decoder=True,
|
915 |
-
)
|
916 |
-
self.dropout = config.dropout
|
917 |
-
self.activation_fn = ACT2FN[config.activation_function]
|
918 |
-
self.activation_dropout = config.activation_dropout
|
919 |
-
|
920 |
-
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
921 |
-
self.encoder_attn = LSGBartDecoderAttention(
|
922 |
-
self.embed_dim,
|
923 |
-
config.decoder_attention_heads,
|
924 |
-
dropout=config.attention_dropout,
|
925 |
-
is_decoder=True,
|
926 |
-
)
|
927 |
-
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
928 |
-
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
929 |
-
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
930 |
-
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
931 |
-
|
932 |
-
def forward(
|
933 |
-
self,
|
934 |
-
hidden_states,
|
935 |
-
attention_mask=None,
|
936 |
-
encoder_hidden_states=None,
|
937 |
-
encoder_attention_mask=None,
|
938 |
-
layer_head_mask=None,
|
939 |
-
cross_attn_layer_head_mask=None,
|
940 |
-
past_key_value=None,
|
941 |
-
output_attentions=False,
|
942 |
-
use_cache=True,
|
943 |
-
):
|
944 |
-
"""
|
945 |
-
Args:
|
946 |
-
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
947 |
-
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
|
948 |
-
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
949 |
-
encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
950 |
-
encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size
|
951 |
-
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
952 |
-
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
|
953 |
-
`(encoder_attention_heads,)`.
|
954 |
-
cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
955 |
-
size `(decoder_attention_heads,)`.
|
956 |
-
past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
957 |
-
output_attentions (:obj:`bool`, `optional`):
|
958 |
-
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
|
959 |
-
returned tensors for more detail.
|
960 |
-
"""
|
961 |
-
residual = hidden_states
|
962 |
-
|
963 |
-
# Self Attention
|
964 |
-
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
965 |
-
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
966 |
-
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
967 |
-
|
968 |
-
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
969 |
-
hidden_states=hidden_states,
|
970 |
-
past_key_value=self_attn_past_key_value,
|
971 |
-
attention_mask=attention_mask,
|
972 |
-
layer_head_mask=layer_head_mask,
|
973 |
-
output_attentions=output_attentions,
|
974 |
-
)
|
975 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
976 |
-
hidden_states = residual + hidden_states
|
977 |
-
hidden_states = self.self_attn_layer_norm(hidden_states)
|
978 |
-
|
979 |
-
# Cross-Attention Block
|
980 |
-
cross_attn_present_key_value = None
|
981 |
-
cross_attn_weights = None
|
982 |
-
if encoder_hidden_states is not None:
|
983 |
-
residual = hidden_states
|
984 |
-
|
985 |
-
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
986 |
-
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
987 |
-
|
988 |
-
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
989 |
-
hidden_states=hidden_states,
|
990 |
-
key_value_states=encoder_hidden_states,
|
991 |
-
attention_mask=encoder_attention_mask,
|
992 |
-
layer_head_mask=cross_attn_layer_head_mask,
|
993 |
-
past_key_value=cross_attn_past_key_value,
|
994 |
-
output_attentions=output_attentions,
|
995 |
-
)
|
996 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
997 |
-
hidden_states = residual + hidden_states
|
998 |
-
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
999 |
-
|
1000 |
-
# add cross-attn to positions 3,4 of present_key_value tuple
|
1001 |
-
present_key_value = present_key_value + cross_attn_present_key_value
|
1002 |
-
|
1003 |
-
# Fully Connected
|
1004 |
-
residual = hidden_states
|
1005 |
-
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
1006 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
|
1007 |
-
hidden_states = self.fc2(hidden_states)
|
1008 |
-
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
1009 |
-
hidden_states = residual + hidden_states
|
1010 |
-
hidden_states = self.final_layer_norm(hidden_states)
|
1011 |
-
|
1012 |
-
outputs = (hidden_states,)
|
1013 |
-
|
1014 |
-
if output_attentions:
|
1015 |
-
outputs += (self_attn_weights, cross_attn_weights)
|
1016 |
-
|
1017 |
-
if use_cache:
|
1018 |
-
outputs += (present_key_value,)
|
1019 |
-
|
1020 |
-
return outputs
|
1021 |
-
|
1022 |
|
1023 |
-
class LSGBartClassificationHead(
|
1024 |
"""Head for sentence-level classification tasks."""
|
1025 |
|
1026 |
def __init__(
|
@@ -1031,55 +649,18 @@ class LSGBartClassificationHead(nn.Module):
|
|
1031 |
pooler_dropout,
|
1032 |
):
|
1033 |
|
1034 |
-
super().__init__()
|
1035 |
-
self.dense = nn.Linear(input_dim, inner_dim)
|
1036 |
-
self.dropout = nn.Dropout(p=pooler_dropout)
|
1037 |
-
self.out_proj = nn.Linear(inner_dim, num_classes)
|
1038 |
-
|
1039 |
-
def forward(self, hidden_states):
|
1040 |
-
|
1041 |
-
hidden_states = self.dropout(hidden_states)
|
1042 |
-
hidden_states = self.dense(hidden_states)
|
1043 |
-
hidden_states = torch.tanh(hidden_states)
|
1044 |
-
hidden_states = self.dropout(hidden_states)
|
1045 |
-
hidden_states = self.out_proj(hidden_states)
|
1046 |
-
return hidden_states
|
1047 |
|
1048 |
|
1049 |
-
class LSGBartPretrainedModel(
|
1050 |
|
1051 |
config_class = LSGBartConfig
|
1052 |
-
base_model_prefix = "model"
|
1053 |
-
supports_gradient_checkpointing = True
|
1054 |
-
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
|
1055 |
-
|
1056 |
-
def _init_weights(self, module):
|
1057 |
-
|
1058 |
-
std = self.config.init_std
|
1059 |
-
if isinstance(module, nn.Linear):
|
1060 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
1061 |
-
if module.bias is not None:
|
1062 |
-
module.bias.data.zero_()
|
1063 |
-
elif isinstance(module, nn.Embedding):
|
1064 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
1065 |
-
if module.padding_idx is not None:
|
1066 |
-
module.weight.data[module.padding_idx].zero_()
|
1067 |
|
1068 |
def _set_gradient_checkpointing(self, module, value=False):
|
1069 |
|
1070 |
-
if isinstance(module, (LSGBartDecoder, LSGBartEncoder)):
|
1071 |
module.gradient_checkpointing = value
|
1072 |
|
1073 |
-
@property
|
1074 |
-
def dummy_inputs(self):
|
1075 |
-
pad_token = self.config.pad_token_id
|
1076 |
-
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
1077 |
-
dummy_inputs = {
|
1078 |
-
"attention_mask": input_ids.ne(pad_token),
|
1079 |
-
"input_ids": input_ids,
|
1080 |
-
}
|
1081 |
-
return dummy_inputs
|
1082 |
-
|
1083 |
|
1084 |
class PretrainedLSGBartModel(LSGBartPretrainedModel):
|
1085 |
|
@@ -1090,7 +671,7 @@ class PretrainedLSGBartModel(LSGBartPretrainedModel):
|
|
1090 |
)
|
1091 |
|
1092 |
|
1093 |
-
class LSGBartEncoder(LSGBartPretrainedModel):
|
1094 |
"""
|
1095 |
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
1096 |
:class:`BartEncoderLayer`.
|
@@ -1115,7 +696,7 @@ class LSGBartEncoder(LSGBartPretrainedModel):
|
|
1115 |
else:
|
1116 |
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
1117 |
|
1118 |
-
self.embed_positions =
|
1119 |
config.max_position_embeddings,
|
1120 |
embed_dim,
|
1121 |
)
|
@@ -1140,12 +721,6 @@ class LSGBartEncoder(LSGBartPretrainedModel):
|
|
1140 |
# Initialize weights and apply final processing
|
1141 |
self.post_init()
|
1142 |
|
1143 |
-
def get_input_embeddings(self):
|
1144 |
-
return self.embed_tokens
|
1145 |
-
|
1146 |
-
def set_input_embeddings(self, value):
|
1147 |
-
self.embed_tokens = value
|
1148 |
-
|
1149 |
def forward(self,
|
1150 |
input_ids=None,
|
1151 |
attention_mask=None,
|
@@ -1335,7 +910,7 @@ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
|
|
1335 |
else:
|
1336 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
1337 |
|
1338 |
-
self.embed_positions =
|
1339 |
config.max_position_embeddings,
|
1340 |
config.d_model,
|
1341 |
)
|
@@ -1348,36 +923,24 @@ class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
|
|
1348 |
self.post_init()
|
1349 |
|
1350 |
|
1351 |
-
class LSGBartModel(LSGBartPretrainedModel):
|
1352 |
|
1353 |
def __init__(self, config):
|
1354 |
|
1355 |
-
|
1356 |
|
1357 |
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
1358 |
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
|
|
1359 |
self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
|
1360 |
self.num_global_tokens = config.num_global_tokens
|
|
|
1361 |
self.encoder = LSGBartEncoder(config, self.shared)
|
1362 |
self.decoder = LSGBartDecoder(config, self.shared)
|
1363 |
|
1364 |
# Initialize weights and apply final processing
|
1365 |
self.post_init()
|
1366 |
|
1367 |
-
def get_input_embeddings(self):
|
1368 |
-
return self.shared
|
1369 |
-
|
1370 |
-
def set_input_embeddings(self, value):
|
1371 |
-
self.shared = value
|
1372 |
-
self.encoder.embed_tokens = self.shared
|
1373 |
-
self.decoder.embed_tokens = self.shared
|
1374 |
-
|
1375 |
-
def get_encoder(self):
|
1376 |
-
return self.encoder
|
1377 |
-
|
1378 |
-
def get_decoder(self):
|
1379 |
-
return self.decoder
|
1380 |
-
|
1381 |
def forward(
|
1382 |
self,
|
1383 |
input_ids=None,
|
|
|
81 |
assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
|
82 |
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
class BaseSelfAttention(nn.Module):
|
85 |
|
86 |
def __init__(
|
|
|
618 |
return x.reshape(n, h, -1, chunk_size, d)
|
619 |
|
620 |
|
621 |
+
class LSGBartEncoderLayer(BartEncoderLayer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
622 |
|
623 |
def __init__(self, config):
|
624 |
|
625 |
+
super().__init__(config)
|
|
|
626 |
self.self_attn = LSGBartEncoderAttention(
|
627 |
config=config,
|
628 |
embed_dim=self.embed_dim,
|
629 |
num_heads=config.encoder_attention_heads,
|
630 |
dropout=config.attention_dropout,
|
631 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
|
633 |
|
634 |
+
class LSGBartDecoderLayer(BartDecoderLayer):
|
635 |
|
636 |
def __init__(self, config):
|
637 |
|
638 |
+
super().__init__(config)
|
639 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
640 |
|
641 |
+
class LSGBartClassificationHead(BartClassificationHead):
|
642 |
"""Head for sentence-level classification tasks."""
|
643 |
|
644 |
def __init__(
|
|
|
649 |
pooler_dropout,
|
650 |
):
|
651 |
|
652 |
+
super().__init__(input_dim, inner_dim, num_classes, pooler_dropout)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
653 |
|
654 |
|
655 |
+
class LSGBartPretrainedModel(BartPretrainedModel):
|
656 |
|
657 |
config_class = LSGBartConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
658 |
|
659 |
def _set_gradient_checkpointing(self, module, value=False):
|
660 |
|
661 |
+
if isinstance(module, (BartDecoder, BartEncoder, LSGBartDecoder, LSGBartEncoder)):
|
662 |
module.gradient_checkpointing = value
|
663 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
664 |
|
665 |
class PretrainedLSGBartModel(LSGBartPretrainedModel):
|
666 |
|
|
|
671 |
)
|
672 |
|
673 |
|
674 |
+
class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
|
675 |
"""
|
676 |
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
677 |
:class:`BartEncoderLayer`.
|
|
|
696 |
else:
|
697 |
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
|
698 |
|
699 |
+
self.embed_positions = BartLearnedPositionalEmbedding(
|
700 |
config.max_position_embeddings,
|
701 |
embed_dim,
|
702 |
)
|
|
|
721 |
# Initialize weights and apply final processing
|
722 |
self.post_init()
|
723 |
|
|
|
|
|
|
|
|
|
|
|
|
|
724 |
def forward(self,
|
725 |
input_ids=None,
|
726 |
attention_mask=None,
|
|
|
910 |
else:
|
911 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
912 |
|
913 |
+
self.embed_positions = BartLearnedPositionalEmbedding(
|
914 |
config.max_position_embeddings,
|
915 |
config.d_model,
|
916 |
)
|
|
|
923 |
self.post_init()
|
924 |
|
925 |
|
926 |
+
class LSGBartModel(LSGBartPretrainedModel, BartModel):
|
927 |
|
928 |
def __init__(self, config):
|
929 |
|
930 |
+
LSGBartPretrainedModel.__init__(self, config)
|
931 |
|
932 |
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
|
933 |
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
|
934 |
+
|
935 |
self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
|
936 |
self.num_global_tokens = config.num_global_tokens
|
937 |
+
|
938 |
self.encoder = LSGBartEncoder(config, self.shared)
|
939 |
self.decoder = LSGBartDecoder(config, self.shared)
|
940 |
|
941 |
# Initialize weights and apply final processing
|
942 |
self.post_init()
|
943 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
944 |
def forward(
|
945 |
self,
|
946 |
input_ids=None,
|