ccdv committed
Commit 73a3f59
1 Parent(s): 89b8bf9

add mask_first_token

Files changed (3):
  1. README.md +5 -1
  2. config.json +1 -0
  3. modeling_lsg_bert.py +5 -0
README.md CHANGED
@@ -52,13 +52,17 @@ You can change various parameters like :
 Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
 
 ```python:
+from transformers import AutoModel
+
 model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
     trust_remote_code=True,
     num_global_tokens=16,
     block_size=64,
     sparse_block_size=64,
-    sparsity_factor=4,
     attention_probs_dropout_prob=0.0
+    sparsity_factor=4,
+    sparsity_type="none",
+    mask_first_token=True
 )
 ```
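For reference, a runnable form of the updated snippet (the committed version drops the comma after `attention_probs_dropout_prob=0.0`); every keyword argument below comes from the diff above, with the same values the README uses:

```python
from transformers import AutoModel

# Same arguments as the README example; the comma after
# attention_probs_dropout_prob is restored so the call parses.
model = AutoModel.from_pretrained(
    "ccdv/legal-lsg-small-uncased-4096",
    trust_remote_code=True,              # model code is loaded from the repo
    num_global_tokens=16,
    block_size=64,
    sparse_block_size=64,
    attention_probs_dropout_prob=0.0,
    sparsity_factor=4,
    sparsity_type="none",                # no sparse attention pattern
    mask_first_token=True,               # option added by this commit
)
```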
 
config.json CHANGED
@@ -28,6 +28,7 @@
   "intermediate_size": 2048,
   "layer_norm_eps": 1e-12,
   "lsh_num_pre_rounds": 1,
+  "mask_first_token": false,
   "max_position_embeddings": 4096,
   "model_type": "bert",
   "num_attention_heads": 8,
modeling_lsg_bert.py CHANGED
@@ -31,6 +31,7 @@ class LSGBertConfig(BertConfig):
         base_model_prefix="lsg",
         block_size=128,
         lsh_num_pre_rounds=1,
+        mask_first_token=False,
         num_global_tokens=1,
         pool_with_global=True,
         sparse_block_size=128,
@@ -46,6 +47,7 @@ class LSGBertConfig(BertConfig):
         self.base_model_prefix = base_model_prefix
         self.block_size = block_size
         self.lsh_num_pre_rounds = lsh_num_pre_rounds
+        self.mask_first_token = mask_first_token
         self.num_global_tokens = num_global_tokens
         self.pool_with_global = pool_with_global
         self.sparse_block_size = sparse_block_size
@@ -1004,6 +1006,7 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
         assert hasattr(config, "block_size") and hasattr(config, "adaptive")
         self.block_size = config.block_size
         self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
         self.pool_with_global = config.pool_with_global
 
         self.embeddings = LSGBertEmbeddings(config)
@@ -1040,6 +1043,8 @@ class LSGBertModel(LSGBertPreTrainedModel, BertModel):
 
         if attention_mask is None:
             attention_mask = torch.ones(n, t, device=inputs_.device)
+        if self.mask_first_token:
+            attention_mask[:,0] = 0
         if token_type_ids is None:
             token_type_ids = torch.zeros(n, t, device=inputs_.device).long()
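The modeling change itself is two lines in the forward pass: when `mask_first_token` is enabled, position 0 of the attention mask is zeroed. A standalone sketch of that behaviour with plain tensors (the rationale in the comment, that the prepended global tokens take over the role of the first token such as `[CLS]`, is an assumption; only the masking comes from the diff):

```python
import torch

# Mimic the two lines added to LSGBertModel: build the default all-ones
# attention mask, then zero out the first position when the flag is on.
n, t = 2, 8                    # batch size, sequence length
attention_mask = torch.ones(n, t)

mask_first_token = True
if mask_first_token:
    attention_mask[:, 0] = 0   # first token no longer receives attention

print(attention_mask[0])
# tensor([0., 1., 1., 1., 1., 1., 1., 1.])
```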