replace -1e4 masks
modeling_lsg_bart.py  CHANGED  (+12 -14)
@@ -3,7 +3,6 @@ import torch
 from transformers.models.bart.modeling_bart import *
 from transformers.models.bart.modeling_bart import _expand_mask
 import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss
 import sys

 AUTO_MAP = {
@@ -16,7 +15,7 @@ AUTO_MAP = {

 class LSGBartConfig(BartConfig):
     """
-    This class overrides :class:`~transformers.
+    This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """

@@ -266,8 +265,8 @@ class LSGAttentionProduct(nn.Module):
         s = (size - step) // 2

         # Pad before block reshaping
-        if is_attn_mask:
-            pad_value =
+        if is_attn_mask:
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -296,7 +295,7 @@ class LSGAttentionProduct(nn.Module):

         # Pad before block reshaping
         if is_attn_mask:
-            pad_value =
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
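Note on the new padding value: torch.finfo(hidden_states.dtype).min is the most negative finite number the tensor's dtype can represent (about -65504 for float16, about -3.4e38 for float32), so padded mask positions are pushed as far down as the dtype allows instead of relying on a fixed constant such as the -1e4 the commit title refers to. A minimal, self-contained sketch of the effect (illustrative tensors, not the repository's code):

import torch

dtype = torch.float16
logits = torch.tensor([[9.0, 7.0, 5.0]], dtype=dtype)   # toy attention scores
additive_mask = torch.zeros_like(logits)
additive_mask[..., -1] = torch.finfo(dtype).min          # dtype-aware "minus infinity"

probs = (logits + additive_mask).float().softmax(dim=-1)
print(probs)                                             # the masked position gets ~0 weight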
@@ -425,7 +424,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)

-        mask =
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -490,8 +489,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8

-        mask =
-
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

     def lsh_round(self, keys, values, mask, output_size):
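In these two hunks the running mask, which at this point appears to hold the number of real tokens pooled into each position (it is used as a denominator just above), is turned into an additive attention mask: occupied positions become 0 and empty positions become the dtype's minimum. A small standalone illustration of that exact transformation, with made-up values:

import torch

# per pooled position: how many real tokens were aggregated
mask = torch.tensor([[3.], [1.], [0.]], dtype=torch.float16)

additive = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
print(additive)   # 0 where tokens exist, ~-65504 where the position is empty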
@@ -739,7 +737,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         n, t = inputs_.size()[:2]

         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:, 0] = 0

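The default attention mask now also inherits inputs_.dtype. Assuming inputs_ is a floating-point tensor (for example embeddings), this keeps the mask in the model's compute dtype, which matters wherever torch.finfo(mask.dtype).min is later derived from it: float32's minimum overflows to -inf once cast to float16, while float16's own minimum stays finite. A quick check of that interaction:

import torch

print(torch.finfo(torch.float32).min)                        # ~ -3.4e38
print(torch.tensor(torch.finfo(torch.float32).min).half())   # -inf after casting to float16
print(torch.finfo(torch.float16).min)                        # -65504.0, stays finite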
@@ -891,7 +889,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         )


-class LSGBartDecoder(
+class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
     Args:
@@ -1032,7 +1030,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         )


-class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
+class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):

     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
@@ -1048,7 +1046,7 @@ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretr
         self.post_init()


-class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
+class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):

     def __init__(self, config: LSGBartConfig, **kwargs):

@@ -1064,7 +1062,7 @@ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPre
         self.model._init_weights(self.classification_head.out_proj)


-class LSGBartForQuestionAnswering(
+class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):

     def __init__(self, config: LSGBartConfig):

@@ -1093,7 +1091,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
         return self.decoder(*args, **kwargs)


-class LSGBartForCausalLM(
+class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):

     def __init__(self, config: LSGBartConfig):

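The remaining hunks reorder the base classes so that LSGBartPretrainedModel comes first in each declaration (several removed declarations are truncated in this view, but the hunk context at @@ -1048 and @@ -1064 shows the previous order, with the stock BART class listed first). Under Python's method resolution order the leftmost base wins attribute lookup, so after the swap anything defined on LSGBartPretrainedModel is found before the upstream BART implementation; which methods this actually affects is not visible in this diff. A toy illustration of the lookup change, using stand-in classes rather than the real ones:

class Stock:                       # plays the role of the upstream BART class
    def setup(self):
        return "stock"

class Extended:                    # plays the role of LSGBartPretrainedModel
    def setup(self):
        return "extended"

class OldOrder(Stock, Extended):   # before the swap: Stock.setup wins
    pass

class NewOrder(Extended, Stock):   # after the swap: Extended.setup wins
    pass

print(OldOrder().setup())   # stock
print(NewOrder().setup())   # extended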