ccdv committed
Commit b19ad10
Parent: 5f9bbab

replace 1e4 mask

Files changed (2):
  1. README.md (+1 -0)
  2. modeling_lsg_bart.py (+12 -14)
README.md CHANGED
@@ -45,6 +45,7 @@ You can change various parameters like :
 * local block size (block_size=128)
 * sparse block size (sparse_block_size=128)
 * sparsity factor (sparsity_factor=2)
+* mask_first_token (mask first token since it is redundant with the first global token)
 * see config.json file
 
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
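As a quick illustration of the parameters listed in this hunk, a minimal loading sketch (not part of the commit; the repository id below is a placeholder, and trust_remote_code=True is assumed to be required for the custom LSG classes):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Placeholder checkpoint id; substitute the actual LSG-BART repository.
checkpoint = "ccdv/lsg-bart-base-4096"

model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    trust_remote_code=True,   # loads the custom modeling_lsg_bart.py code
    block_size=128,           # local block size
    sparse_block_size=128,    # sparse block size
    sparsity_factor=2,        # sparsity factor
    mask_first_token=True,    # mask the first token (redundant with the first global token)
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)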
modeling_lsg_bart.py CHANGED
@@ -3,7 +3,6 @@ import torch
 from transformers.models.bart.modeling_bart import *
 from transformers.models.bart.modeling_bart import _expand_mask
 import torch.nn as nn
-from torch.nn import BCEWithLogitsLoss
 import sys
 
 AUTO_MAP = {
@@ -16,7 +15,7 @@ AUTO_MAP = {
 
 class LSGBartConfig(BartConfig):
     """
-    This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
+    This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -266,8 +265,8 @@ class LSGAttentionProduct(nn.Module):
         s = (size - step) // 2
 
         # Pad before block reshaping
-        if is_attn_mask:
-            pad_value = -10000
+        if is_attn_mask:
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -296,7 +295,7 @@ class LSGAttentionProduct(nn.Module):
 
         # Pad before block reshaping
         if is_attn_mask:
-            pad_value = -10000
+            pad_value = torch.finfo(hidden_states.dtype).min
             hidden_states = hidden_states.transpose(-1, -2)
         else:
             pad_value = 0
@@ -425,7 +424,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)
 
-        mask = - (1. - mask.clamp(0, 1)) * 1e4
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -490,8 +489,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8
 
-        mask = -10000 * (1. - mask.clamp(0, 1))
-
+        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
 
     def lsh_round(self, keys, values, mask, output_size):
@@ -739,7 +737,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         n, t = inputs_.size()[:2]
 
         if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device)
+            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
         if self.mask_first_token:
             attention_mask[:, 0] = 0
 
@@ -891,7 +889,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         )
 
 
-class LSGBartDecoder(BartDecoder, LSGBartPretrainedModel):
+class LSGBartDecoder(LSGBartPretrainedModel, BartDecoder):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`LSGBartDecoderLayer`
     Args:
@@ -1032,7 +1030,7 @@ class LSGBartModel(LSGBartPretrainedModel, BartModel):
         )
 
 
-class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
+class LSGBartForConditionalGeneration(LSGBartPretrainedModel, BartForConditionalGeneration):
 
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
@@ -1048,7 +1046,7 @@ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretr
         self.post_init()
 
 
-class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
+class LSGBartForSequenceClassification(LSGBartPretrainedModel, BartForSequenceClassification):
 
     def __init__(self, config: LSGBartConfig, **kwargs):
 
@@ -1064,7 +1062,7 @@ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPre
         self.model._init_weights(self.classification_head.out_proj)
 
 
-class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
+class LSGBartForQuestionAnswering(LSGBartPretrainedModel, BartForQuestionAnswering):
 
     def __init__(self, config: LSGBartConfig):
 
@@ -1093,7 +1091,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
         return self.decoder(*args, **kwargs)
 
 
-class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
+class LSGBartForCausalLM(LSGBartPretrainedModel, BartForCausalLM):
 
     def __init__(self, config: LSGBartConfig):
 
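The commit title ("replace 1e4 mask") refers to the mask lines above. As a general illustration of the rationale (not code from the repository; the scores are contrived), torch.finfo(dtype).min is dtype-aware and guarantees a masked logit can never exceed an unmasked one, whereas a fixed -1e4 offset can fail when the real scores are themselves very negative:

import torch

def additive_mask(keep, dtype, min_value):
    # keep: 1.0 for tokens to attend to, 0.0 for padding
    # (mirrors the (1. - mask.clamp(0, 1)) * value pattern in the diff)
    return (1.0 - keep.to(dtype)) * min_value

scores = torch.tensor([[-20000.0, -20001.0, 0.0]])  # last position is padding
keep = torch.tensor([[1.0, 1.0, 0.0]])

old = torch.softmax(scores + additive_mask(keep, scores.dtype, -1e4), dim=-1)
new = torch.softmax(scores + additive_mask(keep, scores.dtype, torch.finfo(scores.dtype).min), dim=-1)

print(old)  # padding position dominates: 0 - 1e4 = -1e4 is larger than -20000
print(new)  # padding position is fully suppressed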
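The remaining hunks swap the order of the base classes so that LSGBartPretrainedModel comes first. A small stand-alone sketch of the method-resolution-order effect this relies on (class names are stand-ins, not the real ones): the base listed first wins when both bases define the same attribute or method.

class BartSide:            # stands in for the original BART class
    source = "BartSide"

class LSGSide:             # stands in for LSGBartPretrainedModel
    source = "LSGSide"

class OldOrder(BartSide, LSGSide):   # old declaration order
    pass

class NewOrder(LSGSide, BartSide):   # new declaration order
    pass

print(OldOrder.source)  # "BartSide" -> BART attributes/methods take precedence
print(NewOrder.source)  # "LSGSide"  -> LSG overrides take precedence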