ccdv committed
Commit 1d9841a
1 Parent(s): 93a4f16

bos_token + readme

Files changed (2)
  1. README.md +12 -7
  2. modeling_lsg_distilbert.py +38 -11
README.md CHANGED
@@ -70,26 +70,31 @@ model = AutoModel.from_pretrained("ccdv/lsg-distilbert-base-uncased-4096",
 
 ## Sparse selection type
 
-There are 5 different sparse selection patterns. The best type is task dependent. \
+There are 6 different sparse selection patterns. The best type is task dependent. \
+If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
 Note that for sequences with length < 2*block_size, the type has no effect.
-
-* sparsity_type="norm", select highest norm tokens
+* `sparsity_type="bos_pooling"` (new)
+    * weighted average pooling using the BOS token
+    * Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
+    * Additional parameters:
+        * None
+* `sparsity_type="norm"`, select highest norm tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
        * None
-* sparsity_type="pooling", use average pooling to merge tokens
+* `sparsity_type="pooling"`, use average pooling to merge tokens
    * Works best for a small sparsity_factor (2 to 4)
    * Additional parameters:
        * None
-* sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
+* `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
    * Works best for a large sparsity_factor (4+)
    * LSH relies on random projections, thus inference may differ slightly with different seeds
    * Additional parameters:
        * lsg_num_pre_rounds=1, pre merge tokens n times before computing centroids
-* sparsity_type="stride", use a striding mechanism per head
+* `sparsity_type="stride"`, use a striding mechanism per head
    * Each head will use different tokens strided by sparsify_factor
    * Not recommended if sparsify_factor > num_heads
-* sparsity_type="block_stride", use a striding mechanism per head
+* `sparsity_type="block_stride"`, use a striding mechanism per head
    * Each head will use block of tokens strided by sparsify_factor
    * Not recommended if sparsify_factor > num_heads
 
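For context, a minimal sketch of how the new pattern could be selected at load time, following the `AutoModel.from_pretrained` call visible in the hunk header above. The `sparsity_type` and `sparsity_factor` values below are illustrative assumptions, not part of this commit:

```python
from transformers import AutoModel

# Illustrative values only: the README above reports that bos_pooling works
# best with a rather large sparsity_factor (8, 16, 32).
model = AutoModel.from_pretrained(
    "ccdv/lsg-distilbert-base-uncased-4096",
    trust_remote_code=True,        # loads the custom LSG modeling code from the repo
    sparsity_type="bos_pooling",   # new selection pattern introduced by this commit
    sparsity_factor=8,
)
```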
modeling_lsg_distilbert.py CHANGED
@@ -50,16 +50,16 @@ class LSGDistilBertConfig(DistilBertConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                 setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
-            if self.sparsity_factor > self.encoder_attention_heads:
+            if self.sparsity_factor > self.n_heads:
                 logger.warning(
-                "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
+                "[WARNING CONFIG]: sparsity_factor > n_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
@@ -477,15 +477,16 @@ class LSGSelfAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
             }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -513,7 +514,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -536,7 +537,7 @@ class LSGSelfAttention(BaseSelfAttention):
             mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -552,7 +553,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -572,11 +573,14 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
+
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
         values = self.chunk(values, block_size)
@@ -624,6 +628,29 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         query,
@@ -695,7 +722,7 @@ class LSGSelfAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
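To make the new selection easier to follow, here is a self-contained sketch of what `get_sparse_tokens_with_bos_pooling` computes: every block of `sparsity_factor` keys/values is collapsed into one token, weighted by its attention score against the BOS (first) query. The tensor shapes, the additive-mask convention (0 for kept tokens, `finfo.min` for masked ones), and the helper name are assumptions inferred from the diff, not repository code:

```python
import math
import torch

def bos_pooling_sketch(queries, keys, values, mask, sparsity_factor):
    # queries, keys, values: (batch, heads, seq_len, head_dim)
    # mask: (batch, 1, 1, seq_len), additive (0 = keep, finfo.min = masked)
    n, h, t, d = keys.size()
    b = t // sparsity_factor  # number of blocks, assuming t is a multiple of sparsity_factor

    # split the sequence into blocks of sparsity_factor tokens (the diff's self.chunk)
    keys = keys.reshape(n, h, b, sparsity_factor, d)
    values = values.reshape(n, h, b, sparsity_factor, d)
    mask = mask.reshape(n, 1, b, 1, sparsity_factor)

    # score every token in a block against the BOS query (first query token)
    bos_query = queries[:, :, :1, :].unsqueeze(-3)               # (n, h, 1, 1, d)
    scores = (bos_query @ keys.transpose(-1, -2)) / math.sqrt(d)
    scores = torch.softmax(scores + mask, dim=-1)                # (n, h, b, 1, sparsity_factor)

    # weighted average pooling: one key/value per block
    keys = (scores @ keys).reshape(n, h, b, d)
    values = (scores @ values).reshape(n, h, b, d)

    # rebuild an additive mask at block granularity (mirrors the diff's mean-then-reset step)
    block_mask = mask.mean(dim=-1)
    block_mask[block_mask != torch.finfo(block_mask.dtype).min] = 0
    return keys, values, block_mask.reshape(n, 1, 1, b)

# Toy check: 16 tokens with sparsity_factor=8 are compressed to 2 sparse tokens per head.
q, k, v = (torch.randn(1, 2, 16, 4) for _ in range(3))
m = torch.zeros(1, 1, 1, 16)
sparse_k, sparse_v, sparse_m = bos_pooling_sketch(q, k, v, m, sparsity_factor=8)
print(sparse_k.shape, sparse_v.shape, sparse_m.shape)  # (1, 2, 2, 4), (1, 2, 2, 4), (1, 1, 1, 2)
```

This also clarifies why the forward pass now passes `query_layer` to `get_sparse_elements` and why every `get_sparse_tokens_with_*` signature gained a `queries` argument: bos_pooling is the only pattern that needs the queries, but all selection functions are called through the same dispatch table.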