ccdv committed
Commit 5944198
1 Parent(s): 5651a5c

add mask_first_token

Files changed (3)
  1. README.md +5 -1
  2. config.json +1 -0
  3. modeling_lsg_bart.py +6 -1
README.md CHANGED
@@ -51,13 +51,17 @@ You can change various parameters like :
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
 
 ```python:
+from transformers import AutoModel
+
 model = AutoModel.from_pretrained("ccdv/lsg-bart-large-4096",
     trust_remote_code=True,
     num_global_tokens=16,
     block_size=64,
     sparse_block_size=64,
-    sparsity_factor=4,
     attention_probs_dropout_prob=0.0
+    sparsity_factor=4,
+    sparsity_type="none",
+    mask_first_token=True
 )
 ```
 
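For reference, a runnable form of the updated README snippet might look like the sketch below. It keeps the argument names and values shown in the diff and adds the trailing comma after `attention_probs_dropout_prob=0.0` that Python requires.

```python
from transformers import AutoModel

# Sketch of the updated usage example; values mirror the README diff above.
model = AutoModel.from_pretrained(
    "ccdv/lsg-bart-large-4096",
    trust_remote_code=True,
    num_global_tokens=16,
    block_size=64,
    sparse_block_size=64,
    attention_probs_dropout_prob=0.0,
    sparsity_factor=4,
    sparsity_type="none",    # no sparse attention pattern
    mask_first_token=True,   # new flag introduced by this commit
)
```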
config.json CHANGED
@@ -51,6 +51,7 @@
     "LABEL_2": 2
   },
   "lsh_num_pre_rounds": 1,
+  "mask_first_token": false,
   "max_position_embeddings": 4096,
   "model_type": "bart",
   "no_repeat_ngram_size": 3,
modeling_lsg_bart.py CHANGED
@@ -31,6 +31,7 @@ class LSGBartConfig(BartConfig):
         base_model_prefix="lsg",
         block_size=128,
         lsh_num_pre_rounds=1,
+        mask_first_token=False,
         num_global_tokens=1,
         pass_global_tokens_to_decoder=True,
         pool_with_global=True,
@@ -47,6 +48,7 @@ class LSGBartConfig(BartConfig):
         self.base_model_prefix = base_model_prefix
         self.block_size = block_size
         self.lsh_num_pre_rounds = lsh_num_pre_rounds
+        self.mask_first_token = mask_first_token
         self.num_global_tokens = num_global_tokens
         self.pass_global_tokens_to_decoder = pass_global_tokens_to_decoder
         self.pool_with_global = pool_with_global
@@ -711,6 +713,7 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
         assert hasattr(config, "block_size") and hasattr(config, "adaptive")
         self.block_size = config.block_size
         self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
         self.pool_with_global = config.pool_with_global
         self.pass_global_tokens_to_decoder = config.pass_global_tokens_to_decoder
 
@@ -737,7 +740,9 @@ class LSGBartEncoder(LSGBartPretrainedModel, BartEncoder):
 
         if attention_mask is None:
             attention_mask = torch.ones(n, t, device=inputs_.device)
-
+        if self.mask_first_token:
+            attention_mask[:, 0] = 0
+
         b = self.block_size * 2
         pad = t % self.block_size
 
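The encoder change boils down to zeroing the first column of the attention mask when the flag is set. A self-contained sketch, with a hypothetical `build_attention_mask` helper standing in for the surrounding encoder code:

```python
import torch

def build_attention_mask(n, t, mask_first_token=False, device="cpu"):
    # Same default as the encoder: attend to every position.
    attention_mask = torch.ones(n, t, device=device)
    if mask_first_token:
        # New behavior: hide the first token (e.g. BOS) from the attention pattern.
        attention_mask[:, 0] = 0
    return attention_mask

mask = build_attention_mask(2, 8, mask_first_token=True)
print(mask[0])  # tensor([0., 1., 1., 1., 1., 1., 1., 1.])
```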