ccdv commited on
Commit
de93d26
1 Parent(s): 43c3663

bos_token + readme

Browse files
Files changed (2) hide show
  1. README.md +12 -7
  2. modeling_lsg_bert.py +39 -12
README.md CHANGED
@@ -72,26 +72,31 @@ model = AutoModel.from_pretrained("ccdv/legal-lsg-small-uncased-4096",
72
 
73
  ## Sparse selection type
74
 
75
- There are 5 different sparse selection patterns. The best type is task dependent. \
 
76
  Note that for sequences with length < 2*block_size, the type has no effect.
77
-
78
- * sparsity_type="norm", select highest norm tokens
 
 
 
 
79
  * Works best for a small sparsity_factor (2 to 4)
80
  * Additional parameters:
81
  * None
82
- * sparsity_type="pooling", use average pooling to merge tokens
83
  * Works best for a small sparsity_factor (2 to 4)
84
  * Additional parameters:
85
  * None
86
- * sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
87
  * Works best for a large sparsity_factor (4+)
88
  * LSH relies on random projections, thus inference may differ slightly with different seeds
89
  * Additional parameters:
90
  * lsg_num_pre_rounds=1, pre merge tokens n times before computing centroids
91
- * sparsity_type="stride", use a striding mecanism per head
92
  * Each head will use different tokens strided by sparsify_factor
93
  * Not recommended if sparsify_factor > num_heads
94
- * sparsity_type="block_stride", use a striding mecanism per head
95
  * Each head will use block of tokens strided by sparsify_factor
96
  * Not recommended if sparsify_factor > num_heads
97
 
 
72
 
73
  ## Sparse selection type
74
 
75
+ There are 6 different sparse selection patterns. The best type is task dependent. \
76
+ If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
77
  Note that for sequences with length < 2*block_size, the type has no effect.
78
+ * `sparsity_type="bos_pooling"` (new)
79
+ * weighted average pooling using the BOS token
80
+ * Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
81
+ * Additional parameters:
82
+ * None
83
+ * `sparsity_type="norm"`, select highest norm tokens
84
  * Works best for a small sparsity_factor (2 to 4)
85
  * Additional parameters:
86
  * None
87
+ * `sparsity_type="pooling"`, use average pooling to merge tokens
88
  * Works best for a small sparsity_factor (2 to 4)
89
  * Additional parameters:
90
  * None
91
+ * `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
92
  * Works best for a large sparsity_factor (4+)
93
  * LSH relies on random projections, thus inference may differ slightly with different seeds
94
  * Additional parameters:
95
  * lsg_num_pre_rounds=1, pre merge tokens n times before computing centroids
96
+ * `sparsity_type="stride"`, use a striding mecanism per head
97
  * Each head will use different tokens strided by sparsify_factor
98
  * Not recommended if sparsify_factor > num_heads
99
+ * `sparsity_type="block_stride"`, use a striding mecanism per head
100
  * Each head will use block of tokens strided by sparsify_factor
101
  * Not recommended if sparsify_factor > num_heads
102
 
modeling_lsg_bert.py CHANGED
@@ -54,16 +54,16 @@ class LSGBertConfig(BertConfig):
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
 
57
- if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
58
  logger.warning(
59
- "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
60
  setting sparsity_type=None, computation will skip sparse attention")
61
  self.sparsity_type = None
62
 
63
  if self.sparsity_type in ["stride", "block_stride"]:
64
- if self.sparsity_factor > self.encoder_attention_heads:
65
  logger.warning(
66
- "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
67
  )
68
 
69
  if self.num_global_tokens < 1:
@@ -491,15 +491,16 @@ class LSGSelfAttention(BaseSelfAttention):
491
  "lsh": self.get_sparse_tokens_with_lsh,
492
  "stride": self.get_sparse_tokens_with_stride,
493
  "block_stride": self.get_sparse_tokens_with_block_stride,
 
494
  }
495
 
496
  self.sparsity_type = config.sparsity_type
497
- self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
498
 
499
  if config.sparsity_type == "lsh":
500
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
501
 
502
- def get_sparse_tokens_with_norm(self, keys, values, mask):
503
 
504
  if self.sparsity_factor == 1:
505
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -527,7 +528,7 @@ class LSGSelfAttention(BaseSelfAttention):
527
 
528
  return keys, values, mask
529
 
530
- def get_sparse_tokens_with_pooling(self, keys, values, mask):
531
 
532
  if self.sparsity_factor == 1:
533
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -550,7 +551,7 @@ class LSGSelfAttention(BaseSelfAttention):
550
  mask *= torch.finfo(mask.dtype).min
551
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
552
 
553
- def get_sparse_tokens_with_stride(self, keys, values, mask):
554
 
555
  if self.sparsity_factor == 1:
556
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -566,7 +567,7 @@ class LSGSelfAttention(BaseSelfAttention):
566
 
567
  return keys, values, mask
568
 
569
- def get_sparse_tokens_with_block_stride(self, keys, values, mask):
570
 
571
  if self.sparsity_factor == 1:
572
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -586,10 +587,13 @@ class LSGSelfAttention(BaseSelfAttention):
586
 
587
  return keys, values, mask
588
 
589
- def get_sparse_tokens_with_lsh(self, keys, values, mask):
590
 
591
  if self.sparsity_factor == 1:
592
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
 
 
593
 
594
  block_size = min(self.block_size, self.sparse_block_size)
595
  keys = self.chunk(keys, block_size)
@@ -638,6 +642,29 @@ class LSGSelfAttention(BaseSelfAttention):
638
 
639
  return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
  def forward(
642
  self,
643
  hidden_states,
@@ -757,7 +784,7 @@ class LSGSelfAttention(BaseSelfAttention):
757
  # Get sparse idx
758
  sparse_key, sparse_value, sparse_mask = (None, None, None)
759
  if self.sparse_block_size and self.sparsity_factor > 0:
760
- sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
761
 
762
  # Expand masks on heads
763
  attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -830,7 +857,7 @@ class LSGSelfAttention(BaseSelfAttention):
830
  sparse_key, sparse_value, sparse_mask = (None, None, None)
831
 
832
  if self.sparse_block_size and self.sparsity_factor > 0:
833
- sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
834
 
835
  # Expand masks on heads
836
  attention_mask = attention_mask.expand(-1, h, -1, -1)
 
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
 
57
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
58
  logger.warning(
59
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
60
  setting sparsity_type=None, computation will skip sparse attention")
61
  self.sparsity_type = None
62
 
63
  if self.sparsity_type in ["stride", "block_stride"]:
64
+ if self.sparsity_factor > self.num_attention_heads:
65
  logger.warning(
66
+ "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
67
  )
68
 
69
  if self.num_global_tokens < 1:
 
491
  "lsh": self.get_sparse_tokens_with_lsh,
492
  "stride": self.get_sparse_tokens_with_stride,
493
  "block_stride": self.get_sparse_tokens_with_block_stride,
494
+ "bos_pooling": self.get_sparse_tokens_with_bos_pooling
495
  }
496
 
497
  self.sparsity_type = config.sparsity_type
498
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
499
 
500
  if config.sparsity_type == "lsh":
501
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
502
 
503
+ def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
504
 
505
  if self.sparsity_factor == 1:
506
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
528
 
529
  return keys, values, mask
530
 
531
+ def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
532
 
533
  if self.sparsity_factor == 1:
534
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
551
  mask *= torch.finfo(mask.dtype).min
552
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
553
 
554
+ def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
555
 
556
  if self.sparsity_factor == 1:
557
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
567
 
568
  return keys, values, mask
569
 
570
+ def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
571
 
572
  if self.sparsity_factor == 1:
573
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
587
 
588
  return keys, values, mask
589
 
590
+ def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
591
 
592
  if self.sparsity_factor == 1:
593
  return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
594
+
595
+ if self.sparsity_factor == self.sparse_block_size:
596
+ return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
597
 
598
  block_size = min(self.block_size, self.sparse_block_size)
599
  keys = self.chunk(keys, block_size)
 
642
 
643
  return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
644
 
645
+ def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
646
+
647
+ if self.sparsity_factor == 1:
648
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
649
+
650
+ queries = queries.unsqueeze(-3)
651
+ mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
652
+ keys = self.chunk(keys, self.sparsity_factor)
653
+ values = self.chunk(values, self.sparsity_factor)
654
+
655
+ n, h, b, t, d = keys.size()
656
+ scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
657
+ if mask is not None:
658
+ scores = scores + mask
659
+
660
+ scores = torch.softmax(scores, dim=-1)
661
+ keys = scores @ keys
662
+ values = scores @ values
663
+ mask = mask.mean(dim=-1)
664
+ mask[mask != torch.finfo(mask.dtype).min] = 0
665
+
666
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
667
+
668
  def forward(
669
  self,
670
  hidden_states,
 
784
  # Get sparse idx
785
  sparse_key, sparse_value, sparse_mask = (None, None, None)
786
  if self.sparse_block_size and self.sparsity_factor > 0:
787
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
788
 
789
  # Expand masks on heads
790
  attention_mask = attention_mask.expand(-1, h, -1, -1)
 
857
  sparse_key, sparse_value, sparse_mask = (None, None, None)
858
 
859
  if self.sparse_block_size and self.sparsity_factor > 0:
860
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
861
 
862
  # Expand masks on heads
863
  attention_mask = attention_mask.expand(-1, h, -1, -1)