small fix with torch.finfo
modeling_lsg_pegasus.py  +26 -65
@@ -13,7 +13,7 @@ AUTO_MAP = {
 
 class LSGPegasusConfig(PegasusConfig):
     """
-    This class overrides :class:`~transformers.
+    This class overrides :class:`~transformers.PegasusConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -55,7 +55,8 @@ class LSGPegasusConfig(PegasusConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'],
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                    setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -63,7 +64,7 @@ class LSGPegasusConfig(PegasusConfig):
                 logger.warning(
                     "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
                 )
-
+
         if self.num_global_tokens < 1:
             logger.warning(
                 "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
@@ -71,13 +72,23 @@ class LSGPegasusConfig(PegasusConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
-
+
         if self.sparsity_factor > 0:
             assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
             assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
+
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
+
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
 
 
 class BaseSelfAttention(nn.Module):
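For reference, a minimal sketch of how the new mask_first_token / pool_with_global check behaves. It assumes both flags are accepted as keyword arguments by LSGPegasusConfig, which is what the validation above implies, and that the changed file is importable locally as modeling_lsg_pegasus:

from modeling_lsg_pegasus import LSGPegasusConfig

# Incompatible combination: the first token is masked but not pooled globally.
config = LSGPegasusConfig(mask_first_token=True, pool_with_global=False)

# __init__ logs "[WARNING CONFIG]: pool_with_global==False is not compatible ..."
# and flips the flag back on instead of raising.
print(config.pool_with_global)  # True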
@@ -422,7 +433,8 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -487,7 +499,8 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
 
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
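These two hunks are the fix the commit title refers to: the inverted sparse-token mask is now scaled by torch.finfo(mask.dtype).min, so the fill value always matches the dtype the model actually runs in (fp32, fp16 or bf16) rather than relying on a fixed constant. The returned mask is presumably applied as an additive bias on the attention scores downstream; a minimal standalone sketch of that pattern (not the model's own method):

import torch

def additive_mask(mask: torch.Tensor) -> torch.Tensor:
    # mask > 0 marks positions to keep, mask == 0 marks padded/empty sparse buckets.
    inverted = 1. - mask.clamp(0, 1)          # 1 on masked positions, 0 elsewhere
    # Scale by the most negative finite value of the same dtype: adding this bias to
    # the attention scores sends masked positions to ~0 after softmax, and the value
    # stays representable in half precision (torch.finfo(torch.float16).min == -65504).
    return inverted * torch.finfo(mask.dtype).min

scores = torch.randn(1, 2, 4, 4, dtype=torch.float16)       # (batch, heads, queries, keys)
mask = torch.tensor([1., 1., 1., 0.], dtype=torch.float16)  # last key position is padding
probs = torch.softmax(scores + additive_mask(mask), dim=-1)
print(probs[..., -1])  # ~0 everywhere: the padded key receives no attention weight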
@@ -556,8 +569,6 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
             attention_mask=attention_mask
         )
 
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         return self.reshape_output(context_layer)
 
     # Split input into global tokens and other tokens
@@ -605,8 +616,6 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
 
         # Merge global and local-sparse tokens
         context_layer = torch.cat([bos, context_layer], dim=-2)
-        if head_mask is not None:
-            context_layer = context_layer * head_mask[:, :, :1, :1]
         context_layer = self.reshape_output(context_layer)
 
         return context_layer
@@ -665,21 +674,13 @@ class LSGPegasusEncoderLayer(PegasusEncoderLayer):
             dropout=config.attention_dropout,
         )
 
-
-# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus
-class LSGPegasusDecoderLayer(PegasusDecoderLayer):
-
-    def __init__(self, config: LSGPegasusConfig):
-
-        super().__init__(config)
 
-
 class LSGPegasusPreTrainedModel(PegasusPreTrainedModel):
 
     config_class = LSGPegasusConfig
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (PegasusDecoder, PegasusEncoder,
+        if isinstance(module, (PegasusDecoder, PegasusEncoder, LSGPegasusEncoder)):
             module.gradient_checkpointing = value
 
 
@@ -922,44 +923,6 @@ class LSGPegasusEncoder(LSGPegasusPreTrainedModel, PegasusEncoder):
         )
 
 
-class LSGPegasusDecoder(LSGPegasusPreTrainedModel, PegasusDecoder):
-    """
-    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`PegasusDecoderLayer`
-    Args:
-        config: PegasusConfig
-        embed_tokens (nn.Embedding): output embedding
-    """
-
-    def __init__(self, config: LSGPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
-
-        LSGPegasusPreTrainedModel.__init__(self, config)
-
-        self.dropout = config.dropout
-        self.layerdrop = config.decoder_layerdrop
-        self.padding_idx = config.pad_token_id
-        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        self.adaptive = config.adaptive
-
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
-        self.embed_positions = LSGPegasusSinusoidalPositionalEmbedding(
-            config.max_position_embeddings,
-            config.d_model,
-            self.padding_idx,
-        )
-        self.layers = nn.ModuleList([LSGPegasusDecoderLayer(config) for _ in range(config.decoder_layers)])
-        self.layer_norm = nn.LayerNorm(config.d_model)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
 class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
 
     def __init__(self, config: LSGPegasusConfig):
@@ -971,7 +934,7 @@ class LSGPegasusModel(LSGPegasusPreTrainedModel, PegasusModel):
         self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
         self.num_global_tokens = config.num_global_tokens
         self.encoder = LSGPegasusEncoder(config, self.shared)
-        self.decoder =
+        self.decoder = PegasusDecoder(config, self.shared)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1078,6 +1041,7 @@ class LSGPegasusForConditionalGeneration(LSGPegasusPreTrainedModel, PegasusForCo
     ]
 
     def __init__(self, config: LSGPegasusConfig):
+
        LSGPegasusPreTrainedModel.__init__(self, config)
        self.model = LSGPegasusModel(config)
        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
@@ -1088,18 +1052,15 @@ class LSGPegasusForConditionalGeneration(LSGPegasusPreTrainedModel, PegasusForCo
 
 
 # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
-class LSGPegasusDecoderWrapper(LSGPegasusPreTrainedModel):
+class LSGPegasusDecoderWrapper(LSGPegasusPreTrainedModel, PegasusDecoderWrapper):
     """
     This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
     used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
     """
 
     def __init__(self, config):
-
-        self
-
-    def forward(self, *args, **kwargs):
-        return self.decoder(*args, **kwargs)
+        LSGPegasusPreTrainedModel.__init__(self, config)
+        PegasusDecoderWrapper.__init__(self, config)
 
 
 class LSGPegasusForCausalLM(LSGPegasusPreTrainedModel, PegasusForCausalLM):
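Since this modeling file is distributed as remote code (the first hunk's context line shows the AUTO_MAP registration), picking up a fix like this only requires reloading the checkpoint with remote code enabled. A minimal sketch; the repository id below is a placeholder, not a real checkpoint name:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

repo = "username/lsg-pegasus-checkpoint"  # placeholder repository id

tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForSeq2SeqLM.from_pretrained(
    repo,
    trust_remote_code=True,  # lets AUTO_MAP resolve to the LSGPegasus* classes in modeling_lsg_pegasus.py
)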