ccdv committed on
Commit
682cd96
1 Parent(s): 1419292

bos_token + readme

Files changed (2)
  1. README.md +12 -7
  2. modeling_lsg_pegasus.py +38 -11
README.md CHANGED
@@ -69,26 +69,31 @@ model = AutoModel.from_pretrained("ccdv/lsg-pegasus-large-4096",
 
 ## Sparse selection type
 
-There are 5 different sparse selection patterns. The best type is task dependent. \
+There are 6 different sparse selection patterns. The best type is task dependent. \
+If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
 Note that for sequences with length < 2*block_size, the type has no effect.
-
-* sparsity_type="norm", select highest norm tokens
+* `sparsity_type="bos_pooling"` (new)
+    * weighted average pooling using the BOS token
+    * Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
+    * Additional parameters:
+        * None
+* `sparsity_type="norm"`, select highest norm tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
         * None
-* sparsity_type="pooling", use average pooling to merge tokens
+* `sparsity_type="pooling"`, use average pooling to merge tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
         * None
-* sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
+* `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
     * Works best for a large sparsity_factor (4+)
     * LSH relies on random projections, thus inference may differ slightly with different seeds
    * Additional parameters:
        * lsg_num_pre_rounds=1, pre-merge tokens n times before computing centroids
-* sparsity_type="stride", use a striding mechanism per head
+* `sparsity_type="stride"`, use a striding mechanism per head
    * Each head will use different tokens strided by sparsity_factor
    * Not recommended if sparsity_factor > num_heads
-* sparsity_type="block_stride", use a striding mechanism per head
+* `sparsity_type="block_stride"`, use a striding mechanism per head
    * Each head will use blocks of tokens strided by sparsity_factor
    * Not recommended if sparsity_factor > num_heads
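
The selection type is chosen through the same model kwargs as the existing ones (the `AutoModel.from_pretrained("ccdv/lsg-pegasus-large-4096", ...)` call referenced in the hunk header above). A minimal usage sketch for the new `bos_pooling` type, assuming the usual LSG loading pattern with `trust_remote_code` and config overrides passed as keyword arguments; the factor value is only illustrative:

```python
from transformers import AutoModel, AutoTokenizer

# Sketch: enable the new BOS-pooling sparse selection when loading the model.
# sparsity_factor=8 is an example value; the list above suggests 8, 16 or 32.
model = AutoModel.from_pretrained(
    "ccdv/lsg-pegasus-large-4096",
    trust_remote_code=True,        # the LSG modeling code lives in the model repo
    sparsity_type="bos_pooling",   # new selection pattern introduced by this commit
    sparsity_factor=8,             # a rather large factor, as recommended for bos_pooling
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-pegasus-large-4096")
```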
 
modeling_lsg_pegasus.py CHANGED
@@ -53,9 +53,9 @@ class LSGPegasusConfig(PegasusConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                 setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
@@ -343,7 +343,7 @@ class LSGAttentionProduct(nn.Module):
         return x.reshape(*x.size()[:-2], n_blocks, -1, d)
 
 
-class LSGPegasusEncoderAttention(BaseSelfAttention):
+class LSGPegasusEncoderSelfAttention(BaseSelfAttention):
     '''
     Compute local attention with overlapping blocks
     Use global attention for tokens with highest norm
@@ -378,15 +378,16 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
             }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -414,7 +415,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -437,7 +438,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
             mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -453,7 +454,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -473,11 +474,14 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
+
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
         values = self.chunk(values, block_size)
@@ -525,6 +529,29 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         hidden_states,
@@ -594,7 +621,7 @@ class LSGPegasusEncoderAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -667,7 +694,7 @@ class LSGPegasusEncoderLayer(PegasusEncoderLayer):
     def __init__(self, config: LSGPegasusConfig):
 
         super().__init__(config)
-        self.self_attn = LSGPegasusEncoderAttention(
+        self.self_attn = LSGPegasusEncoderSelfAttention(
            config=config,
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
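
In words, `get_sparse_tokens_with_bos_pooling` compresses every block of `sparsity_factor` keys/values into a single token, weighted by how strongly the BOS-token query (the first query position) attends to the tokens of that block. A condensed, self-contained sketch of the same computation, with masking omitted and tensor names assumed rather than copied verbatim from the file:

```python
import math
import torch

def bos_pooling_sketch(queries, keys, values, sparsity_factor):
    # queries/keys/values: (batch, heads, tokens, head_dim);
    # tokens assumed divisible by sparsity_factor.
    n, h, t, d = keys.size()
    # Split the sequence into blocks of sparsity_factor tokens
    keys = keys.reshape(n, h, t // sparsity_factor, sparsity_factor, d)
    values = values.reshape(n, h, t // sparsity_factor, sparsity_factor, d)

    # Score every token of a block against the BOS-token query of the same head
    bos_query = queries[:, :, :1, :].unsqueeze(-3)                 # (n, h, 1, 1, d)
    scores = bos_query @ keys.transpose(-1, -2) / math.sqrt(d)     # (n, h, blocks, 1, sparsity_factor)
    scores = torch.softmax(scores, dim=-1)

    # Weighted average pooling inside each block -> one key/value per block
    keys = (scores @ keys).reshape(n, h, -1, d)
    values = (scores @ values).reshape(n, h, -1, d)
    return keys, values

# 1 sequence, 2 heads, 16 tokens, head dim 8, factor 4 -> 4 sparse tokens per head
q = torch.randn(1, 2, 16, 8)
k = torch.randn(1, 2, 16, 8)
v = torch.randn(1, 2, 16, 8)
sk, sv = bos_pooling_sketch(q, k, v, sparsity_factor=4)
print(sk.shape, sv.shape)  # torch.Size([1, 2, 4, 8]) torch.Size([1, 2, 4, 8])
```

Because each block collapses to one token, a sequence of t tokens contributes t / sparsity_factor sparse keys/values, which is why the README recommends rather large factors (8, 16, 32) for this selection type.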