ccdv committed on
Commit
be0d373
1 Parent(s): a5f1f16

bos_token + readme

Files changed (2):
1. README.md (+12 / -7)
2. modeling_lsg_xlm_roberta.py (+39 / -12)
README.md CHANGED
@@ -69,26 +69,31 @@ model = AutoModel.from_pretrained("ccdv/lsg-xlm-roberta-base-4096",
 
 ## Sparse selection type
 
-There are 5 different sparse selection patterns. The best type is task dependent. \
+There are 6 different sparse selection patterns. The best type is task dependent. \
+If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
 Note that for sequences with length < 2*block_size, the type has no effect.
-
-* sparsity_type="norm", select highest norm tokens
+* `sparsity_type="bos_pooling"` (new)
+  * weighted average pooling using the BOS token
+  * Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
+  * Additional parameters:
+    * None
+* `sparsity_type="norm"`, select highest norm tokens
   * Works best for a small sparsity_factor (2 to 4)
   * Additional parameters:
     * None
-* sparsity_type="pooling", use average pooling to merge tokens
+* `sparsity_type="pooling"`, use average pooling to merge tokens
   * Works best for a small sparsity_factor (2 to 4)
   * Additional parameters:
     * None
-* sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
+* `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
   * Works best for a large sparsity_factor (4+)
   * LSH relies on random projections, thus inference may differ slightly with different seeds
   * Additional parameters:
     * lsg_num_pre_rounds=1, pre-merge tokens n times before computing centroids
-* sparsity_type="stride", use a striding mechanism per head
+* `sparsity_type="stride"`, use a striding mechanism per head
   * Each head will use different tokens strided by sparsity_factor
   * Not recommended if sparsity_factor > num_heads
-* sparsity_type="block_stride", use a striding mechanism per head
+* `sparsity_type="block_stride"`, use a striding mechanism per head
   * Each head will use block of tokens strided by sparsity_factor
   * Not recommended if sparsity_factor > num_heads
 
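For reference, a minimal usage sketch (not part of this commit) of how the new `bos_pooling` selection pattern documented above could be enabled when loading the checkpoint. It assumes the extra keyword arguments are forwarded to the LSG config, as in the other examples of the model card, and that the repo's custom code is loaded with `trust_remote_code=True`:

```python
from transformers import AutoModel, AutoTokenizer

# Hypothetical sketch: enable the new "bos_pooling" sparse selection pattern.
model = AutoModel.from_pretrained(
    "ccdv/lsg-xlm-roberta-base-4096",
    trust_remote_code=True,        # load modeling_lsg_xlm_roberta.py from the repo
    sparsity_type="bos_pooling",   # new selection pattern added by this commit
    sparsity_factor=8,             # bos_pooling favours a rather large factor (8, 16, 32)
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")
```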
modeling_lsg_xlm_roberta.py CHANGED
@@ -53,16 +53,16 @@ class LSGXLMRobertaConfig(XLMRobertaConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                     setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
-            if self.sparsity_factor > self.encoder_attention_heads:
+            if self.sparsity_factor > self.num_attention_heads:
                 logger.warning(
-                    "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
+                    "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
@@ -497,15 +497,16 @@ class LSGSelfAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
             }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -533,7 +534,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -556,7 +557,7 @@ class LSGSelfAttention(BaseSelfAttention):
             mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -572,7 +573,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -592,11 +593,14 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
+
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
         values = self.chunk(values, block_size)
@@ -644,6 +648,29 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         hidden_states,
@@ -763,7 +790,7 @@ class LSGSelfAttention(BaseSelfAttention):
         # Get sparse idx
         sparse_key, sparse_value, sparse_mask = (None, None, None)
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -836,7 +863,7 @@ class LSGSelfAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
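To illustrate what the new `get_sparse_tokens_with_bos_pooling` method computes, here is a small self-contained sketch (not from the commit): each block of `sparsity_factor` keys/values is reduced to one token, weighted by the attention scores of the BOS query over that block. The `chunk` helper, tensor shapes, and the simplified zero mask are assumptions for the demo; the real method also rebuilds the sparse attention mask.

```python
import math
import torch

def chunk(x, chunk_size):
    # (n, h, t, d) -> (n, h, t // chunk_size, chunk_size, d)
    n, h, t, d = x.size()
    return x.reshape(n, h, t // chunk_size, chunk_size, d)

def bos_pooling(queries, keys, values, mask, sparsity_factor):
    # Weighted average pooling: each block of `sparsity_factor` tokens is merged
    # into one token, weighted by the attention of the BOS query over the block.
    queries = queries.unsqueeze(-3)                                          # (n, h, 1, t, d)
    mask = chunk(mask.transpose(-1, -2), sparsity_factor).transpose(-1, -2)  # (n, 1, b, 1, f)
    keys = chunk(keys, sparsity_factor)                                      # (n, h, b, f, d)
    values = chunk(values, sparsity_factor)

    n, h, b, t, d = keys.size()
    scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)   # (n, h, b, 1, f)
    scores = torch.softmax(scores + mask, dim=-1)

    keys = scores @ keys      # (n, h, b, 1, d): one pooled key per block
    values = scores @ values  # (n, h, b, 1, d): one pooled value per block
    return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d)

# Dummy tensors: batch 1, 2 heads, 16 tokens, head dim 8, sparsity_factor 4
n, h, t, d, f = 1, 2, 16, 8, 4
q, k, v = torch.randn(n, h, t, d), torch.randn(n, h, t, d), torch.randn(n, h, t, d)
m = torch.zeros(n, 1, 1, t)  # additive attention mask, 0 = token kept
pooled_k, pooled_v = bos_pooling(q, k, v, m, f)
print(pooled_k.shape)  # torch.Size([1, 2, 4, 8]) -> t / sparsity_factor sparse tokens
```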