ccdv committed on
Commit e9aa08a
1 Parent(s): 99b8f99
Files changed (4)
  1. README.md +20 -11
  2. config.json +1 -1
  3. modeling_lsg_bart.py +100 -484
  4. pytorch_model.bin +1 -1
README.md CHANGED
@@ -23,15 +23,23 @@ should probably proofread and complete it, then remove this comment. -->
23
  This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the scientific_papers pubmed dataset. \
24
  It achieves the following results on the test set:
25
 
26
- | Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
27
- |:------ |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
28
- | 4096 | Local | 256 | 0 | 768 | 47.33 | 21.67 | 28.53 | 43.67 |
29
- | 4096 | Local | 128 | 0 | 384 | 46.84 | 21.24 | 28.22 | 43.15 |
30
- | 4096 | Pooling | 128 | 4 | 644 | 47.07 | 21.41 | 28.40 | 43.36 |
31
- | 4096 | Stride | 128 | 4 | 644 | 47.02 | 21.46 | 28.33 | 43.35 |
32
- | 4096 | Norm | 128 | 4 | 644 | 47.01 | 21.32 | 28.26 | 43.33 |
33
- | 4096 | LSH | 128 | 4 | 644 | 46.92 | 21.27 | 28.26 | 43.26 |
34
-
35
 
36
 
37
  ## Model description
@@ -61,7 +69,8 @@ The following hyperparameters were used during training:
61
  - total_train_batch_size: 32
62
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
63
  - lr_scheduler_type: linear
64
- - num_epochs: 7.0
 
65
 
66
  ### Generate hyperparameters
67
 
@@ -69,13 +78,13 @@ The following hyperparameters were used during generation:
69
  - dataset_name: scientific_papers
70
  - dataset_config_name: pubmed
71
  - eval_batch_size: 8
 
72
  - early_stopping: True
73
  - ignore_pad_token_for_loss: True
74
  - length_penalty: 2.0
75
  - max_length: 512
76
  - min_length: 128
77
  - num_beams: 5
78
- - num_samples: None
79
  - no_repeat_ngram_size: None
80
  - seed: 123
81
 
 
23
  This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the scientific_papers pubmed dataset. \
24
  It achieves the following results on the test set:
25
 
26
+ | Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
27
+ |:------ |:------------ |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
28
+ | 4096 | Local | 256 | 0 | 768 | 47.37 | 21.74 | 28.59 | 43.67 |
29
+ | 4096 | Local | 128 | 0 | 384 | 47.02 | 21.33 | 28.34 | 43.31 |
30
+ | 4096 | Pooling | 128 | 4 | 644 | 47.11 | 21.42 | 28.43 | 43.40 |
31
+ | 4096 | Stride | 128 | 4 | 644 | 47.16 | 21.49 | 28.38 | 43.44 |
32
+ | 4096 | Norm | 128 | 4 | 644 | 47.09 | 21.44 | 28.40 | 43.36 |
33
+ | 4096 | LSH | 128 | 4 | 644 | 47.11 | 21.41 | 28.41 | 43.42 |
34
+
35
+ With blocks of size 32 (lower resources):
36
+ | Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
37
+ |:------ |:------------ |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
38
+ | 4096 | Pooling | 32 | 4 | 160 | 44.60 | 19.35 | 26.83 | 40.85 |
39
+ | 4096 | Stride | 32 | 4 | 160 | 45.52 | 20.07 | 27.39 | 41.75 |
40
+ | 4096 | Block Stride | 32 | 4 | 160 | 45.30 | 19.89 | 27.22 | 41.54 |
41
+ | 4096 | Norm | 32 | 4 | 160 | 44.30 | 19.05 | 26.57 | 40.47 |
42
+ | 4096 | LSH | 32 | 4 | 160 | 44.53 | 19.27 | 26.84 | 40.74 |
43
 
44
 
45
  ## Model description
 
69
  - total_train_batch_size: 32
70
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
71
  - lr_scheduler_type: linear
72
+ - lr_scheduler_warmup_ratio: 0.1
73
+ - num_epochs: 8.0
74
 
75
  ### Generate hyperparameters
76
 
 
78
  - dataset_name: scientific_papers
79
  - dataset_config_name: pubmed
80
  - eval_batch_size: 8
81
+ - eval_samples: 6658
82
  - early_stopping: True
83
  - ignore_pad_token_for_loss: True
84
  - length_penalty: 2.0
85
  - max_length: 512
86
  - min_length: 128
87
  - num_beams: 5
 
88
  - no_repeat_ngram_size: None
89
  - seed: 123
90
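The generation settings above map directly onto `model.generate` arguments. A minimal usage sketch, assuming the checkpoint is loaded through the standard `transformers` auto classes with `trust_remote_code=True` (the repository ships custom LSG modeling code) and assuming `ccdv/lsg-bart-base-4096-pubmed` as the repo id:

```python
# Sketch only: the repo id and loading path are assumptions, not confirmed by this commit.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "ccdv/lsg-bart-base-4096-pubmed"  # assumed repo id for this checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

article = "..."  # a long PubMed-style article, up to 4096 tokens
inputs = tokenizer(article, return_tensors="pt", truncation=True, max_length=4096)

# Beam-search settings taken from the "Generate hyperparameters" list above.
summary_ids = model.generate(
    **inputs,
    num_beams=5,
    length_penalty=2.0,
    max_length=512,
    min_length=128,
    early_stopping=True,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```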
 
config.json CHANGED
@@ -68,7 +68,7 @@
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
  "sparsity_factor": 2,
71
- "sparsity_type": "pooling",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
 
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
  "sparsity_factor": 2,
71
+ "sparsity_type": "none",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
modeling_lsg_bart.py CHANGED
@@ -41,8 +41,6 @@ class LSGBartConfig(BartConfig):
41
  ):
42
  """Constructs LSGConfig."""
43
  super().__init__(**kwargs)
44
-
45
- assert sparsity_type in ["norm", "lsh", "pooling", "stride"], "Sparsity mode must be 'norm', 'lsh' or 'pooling'"
46
 
47
  self.adaptive = adaptive
48
  self.auto_map = AUTO_MAP
@@ -55,7 +53,33 @@ class LSGBartConfig(BartConfig):
55
  self.sparse_block_size = sparse_block_size
56
  self.sparsity_factor = sparsity_factor
57
  self.sparsity_type = sparsity_type
58

59
 
60
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
61
  """
@@ -207,7 +231,7 @@ class LSGAttentionProduct(nn.Module):
207
 
208
  # Shape of blocks
209
  self.local_shapes = (self.block_size*3, self.block_size)
210
- if self.sparsity_factor > 0:
211
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
212
 
213
  self.attention = BaseAttentionProduct(config)
@@ -306,9 +330,12 @@ class LSGAttentionProduct(nn.Module):
306
 
307
  size, step = self.sparse_shapes
308

309
  # n, h, t, d*2 + 1
310
  size = size*2
311
- s = (size - step) // 2
312
 
313
  # Pad before block reshaping
314
  if is_attn_mask:
@@ -326,11 +353,16 @@ class LSGAttentionProduct(nn.Module):
326
  # Make blocks
327
  hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
328

329
  # Indexes for selection
330
- u = (size - self.block_size * 3 // self.sparsity_factor) // 2
331
  s = self.sparse_block_size
332
 
333
- return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u:-u+s, :]], dim=-2)
 
334
 
335
  def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
336
 
@@ -383,21 +415,15 @@ class LSGBartEncoderAttention(BaseSelfAttention):
383
  }
384
 
385
  self.sparsity_type = config.sparsity_type
386
- self.get_sparse_elements = sparse_functions[self.sparsity_type]
387
-
388
- if config.sparsity_type == "stride":
389
- if config.sparsity_factor > config.encoder_attention_heads:
390
- logger.warning(
391
- "Warning: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
392
- )
393
 
394
  if config.sparsity_type == "lsh":
395
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
396
-
397
  def get_sparse_tokens_with_norm(self, keys, values, mask):
398
 
399
  if self.sparsity_factor == 1:
400
- return keys, values, mask
401
 
402
  with torch.no_grad():
403
 
@@ -425,7 +451,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
425
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
426
 
427
  if self.sparsity_factor == 1:
428
- return keys, values, mask
429
 
430
  keys = self.chunk(keys, self.sparsity_factor)
431
  values = self.chunk(values, self.sparsity_factor)
@@ -447,13 +473,30 @@ class LSGBartEncoderAttention(BaseSelfAttention):
447
  def get_sparse_tokens_with_stride(self, keys, values, mask):
448
 
449
  if self.sparsity_factor == 1:
450
- return keys, values, mask
451
 
452
  n, h, t, d = keys.size()
453
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
454
  sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
455
  sparse_idx = sparse_idx.expand(n, h, -1, 1)
456

457
  keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
458
  values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
459
  mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
@@ -463,7 +506,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
463
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
464
 
465
  if self.sparsity_factor == 1:
466
- return keys, values, mask
467
 
468
  block_size = min(self.block_size, self.sparse_block_size)
469
  keys = self.chunk(keys, block_size)
@@ -480,9 +523,9 @@ class LSGBartEncoderAttention(BaseSelfAttention):
480
  extra_factor = 1
481
 
482
  for _ in range(self.lsh_num_pre_rounds):
483
- keys, values, mask = self.lsg_round(keys, values, mask, t*extra_factor)
484
 
485
- keys, values, mask = self.lsg_round(keys, values, mask, t//self.sparsity_factor)
486
  keys /= mask + 1e-8
487
  values /= mask + 1e-8
488
 
@@ -490,7 +533,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
490
 
491
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
492
 
493
- def lsg_round(self, keys, values, mask, output_size):
494
 
495
  with torch.no_grad():
496
 
@@ -1130,7 +1173,8 @@ class LSGBartEncoder(LSGBartPretrainedModel):
1130
 
1131
  # else adaptive sequence length
1132
  elif self.adaptive:
1133
- s = int(torch.max(attention_mask.sum(dim=-1)))
 
1134
  if s < t and self.block_size is not None:
1135
  s = max(2, s // self.block_size + 1) * self.block_size if s > b else s
1136
  if input_ids is not None:
@@ -1293,6 +1337,7 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1293
  self.padding_idx = config.pad_token_id
1294
  self.max_target_positions = config.max_position_embeddings
1295
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
1296
 
1297
  if embed_tokens is not None:
1298
  self.embed_tokens = embed_tokens
@@ -1335,6 +1380,15 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1335
 
1336
  return combined_attention_mask
1337
 
 
 
 
 
 
 
 
 
 
1338
  def forward(
1339
  self,
1340
  input_ids=None,
@@ -1375,12 +1429,14 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1375
  if inputs_embeds is None:
1376
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1377
 
1378
- # Cut
1379
- if attention_mask is not None:
1380
- max_len = int(attention_mask.sum(dim=-1).max())
1381
- inputs_embeds = inputs_embeds[:, :max_len]
1382
- attention_mask = attention_mask[..., :max_len]
1383
- input_shape = inputs_embeds.size()[:-1]
1384
 
1385
  attention_mask = self._prepare_decoder_attention_mask(
1386
  attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1474,6 +1530,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1474
  if encoder_hidden_states is not None:
1475
  all_cross_attentions += (layer_outputs[2],)
1476
 
 
 
 
1477
  # add hidden states from the last decoder layer
1478
  if output_hidden_states:
1479
  all_hidden_states += (hidden_states,)
@@ -1610,14 +1669,14 @@ class LSGBartModel(LSGBartPretrainedModel):
1610
  )
1611
 
1612
 
1613
- class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1614
 
1615
  base_model_prefix = "model"
1616
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1617
 
1618
  def __init__(self, config):
1619
 
1620
- super().__init__(config)
1621
  self.model = LSGBartModel(config)
1622
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1623
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
@@ -1625,157 +1684,12 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1625
  # Initialize weights and apply final processing
1626
  self.post_init()
1627
 
1628
- def get_encoder(self):
1629
- return self.model.get_encoder()
1630
-
1631
- def get_decoder(self):
1632
- return self.model.get_decoder()
1633
 
1634
- def resize_token_embeddings(self, new_num_tokens):
1635
- new_embeddings = super().resize_token_embeddings(new_num_tokens)
1636
- self._resize_final_logits_bias(new_num_tokens)
1637
- return new_embeddings
1638
 
1639
- def _resize_final_logits_bias(self, new_num_tokens):
1640
- old_num_tokens = self.final_logits_bias.shape[-1]
1641
- if new_num_tokens <= old_num_tokens:
1642
- new_bias = self.final_logits_bias[:, :new_num_tokens]
1643
- else:
1644
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1645
- new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1646
- self.register_buffer("final_logits_bias", new_bias)
1647
-
1648
- def get_output_embeddings(self):
1649
- return self.lm_head
1650
-
1651
- def set_output_embeddings(self, new_embeddings):
1652
- self.lm_head = new_embeddings
1653
-
1654
- def forward(
1655
- self,
1656
- input_ids=None,
1657
- attention_mask=None,
1658
- decoder_input_ids=None,
1659
- decoder_attention_mask=None,
1660
- head_mask=None,
1661
- decoder_head_mask=None,
1662
- cross_attn_head_mask=None,
1663
- encoder_outputs=None,
1664
- past_key_values=None,
1665
- inputs_embeds=None,
1666
- decoder_inputs_embeds=None,
1667
- labels=None,
1668
- use_cache=None,
1669
- output_attentions=None,
1670
- output_hidden_states=None,
1671
- return_dict=None,
1672
- ):
1673
 
1674
- r"""
1675
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1676
- Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
1677
- config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1678
- (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1679
- Returns:
1680
- """
1681
-
1682
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1683
-
1684
- if labels is not None:
1685
- if decoder_input_ids is None and decoder_inputs_embeds is None:
1686
- decoder_input_ids = shift_tokens_right(
1687
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
1688
- )
1689
-
1690
- outputs = self.model(
1691
- input_ids,
1692
- attention_mask=attention_mask,
1693
- decoder_input_ids=decoder_input_ids,
1694
- encoder_outputs=encoder_outputs,
1695
- decoder_attention_mask=decoder_attention_mask,
1696
- head_mask=head_mask,
1697
- decoder_head_mask=decoder_head_mask,
1698
- cross_attn_head_mask=cross_attn_head_mask,
1699
- past_key_values=past_key_values,
1700
- inputs_embeds=inputs_embeds,
1701
- decoder_inputs_embeds=decoder_inputs_embeds,
1702
- use_cache=use_cache,
1703
- output_attentions=output_attentions,
1704
- output_hidden_states=output_hidden_states,
1705
- return_dict=return_dict,
1706
- )
1707
-
1708
-
1709
- lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1710
-
1711
- masked_lm_loss = None
1712
- if labels is not None:
1713
- loss_fct = CrossEntropyLoss()
1714
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1715
-
1716
- if not return_dict:
1717
- output = (lm_logits,) + outputs[1:]
1718
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1719
-
1720
- return Seq2SeqLMOutput(
1721
- loss=masked_lm_loss,
1722
- logits=lm_logits,
1723
- past_key_values=outputs.past_key_values,
1724
- decoder_hidden_states=outputs.decoder_hidden_states,
1725
- decoder_attentions=outputs.decoder_attentions,
1726
- cross_attentions=outputs.cross_attentions,
1727
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1728
- encoder_hidden_states=outputs.encoder_hidden_states,
1729
- encoder_attentions=outputs.encoder_attentions,
1730
- )
1731
-
1732
- def prepare_inputs_for_generation(
1733
- self,
1734
- decoder_input_ids,
1735
- past=None,
1736
- attention_mask=None,
1737
- head_mask=None,
1738
- decoder_head_mask=None,
1739
- cross_attn_head_mask=None,
1740
- use_cache=None,
1741
- encoder_outputs=None,
1742
- **kwargs
1743
- ):
1744
- # cut decoder_input_ids if past is used
1745
- if past is not None:
1746
- decoder_input_ids = decoder_input_ids[:, -1:]
1747
-
1748
- return {
1749
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
1750
- "encoder_outputs": encoder_outputs,
1751
- "past_key_values": past,
1752
- "decoder_input_ids": decoder_input_ids,
1753
- "attention_mask": attention_mask,
1754
- "head_mask": head_mask,
1755
- "decoder_head_mask": decoder_head_mask,
1756
- "cross_attn_head_mask": cross_attn_head_mask,
1757
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1758
- }
1759
-
1760
- def prepare_decoder_input_ids_from_labels(self, labels):
1761
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1762
-
1763
- @staticmethod
1764
- def _reorder_cache(past, beam_idx):
1765
- reordered_past = ()
1766
- for layer_past in past:
1767
- # cached cross_attention states don't have to be reordered -> they are always the same
1768
- reordered_past += (
1769
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1770
- )
1771
- return reordered_past
1772
-
1773
-
1774
- class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1775
-
1776
- def __init__(self, config, **kwargs):
1777
-
1778
- super().__init__(config, **kwargs)
1779
  self.model = LSGBartModel(config)
1780
  self.classification_head = LSGBartClassificationHead(
1781
  config.d_model,
@@ -1786,115 +1700,12 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1786
  self.model._init_weights(self.classification_head.dense)
1787
  self.model._init_weights(self.classification_head.out_proj)
1788
 
1789
- def forward(
1790
- self,
1791
- input_ids=None,
1792
- attention_mask=None,
1793
- decoder_input_ids=None,
1794
- decoder_attention_mask=None,
1795
- head_mask=None,
1796
- decoder_head_mask=None,
1797
- cross_attn_head_mask=None,
1798
- encoder_outputs=None,
1799
- inputs_embeds=None,
1800
- decoder_inputs_embeds=None,
1801
- labels=None,
1802
- use_cache=None,
1803
- output_attentions=None,
1804
- output_hidden_states=None,
1805
- return_dict=None,
1806
- ):
1807
-
1808
- r"""
1809
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1810
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1811
- config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1812
- """
1813
 
1814
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1815
- if labels is not None:
1816
- use_cache = False
1817
 
1818
- if input_ids is None and inputs_embeds is not None:
1819
- raise NotImplementedError(
1820
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1821
- )
1822
 
1823
- outputs = self.model(
1824
- input_ids,
1825
- attention_mask=attention_mask,
1826
- decoder_input_ids=decoder_input_ids,
1827
- decoder_attention_mask=decoder_attention_mask,
1828
- head_mask=head_mask,
1829
- decoder_head_mask=decoder_head_mask,
1830
- cross_attn_head_mask=cross_attn_head_mask,
1831
- encoder_outputs=encoder_outputs,
1832
- inputs_embeds=inputs_embeds,
1833
- decoder_inputs_embeds=decoder_inputs_embeds,
1834
- use_cache=use_cache,
1835
- output_attentions=output_attentions,
1836
- output_hidden_states=output_hidden_states,
1837
- return_dict=return_dict,
1838
- )
1839
- hidden_states = outputs[0] # last hidden state
1840
-
1841
- eos_mask = input_ids.eq(self.config.eos_token_id)
1842
-
1843
- t, t_ = eos_mask.size()[-1], hidden_states.size()[-2]
1844
- if t > t_:
1845
- eos_mask = eos_mask[:, :t_]
1846
-
1847
- if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1848
- raise ValueError("All examples must have the same number of <eos> tokens.")
1849
- sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1850
- :, -1, :
1851
- ]
1852
- logits = self.classification_head(sentence_representation)
1853
-
1854
- loss = None
1855
- if labels is not None:
1856
- if self.config.problem_type is None:
1857
- if self.config.num_labels == 1:
1858
- self.config.problem_type = "regression"
1859
- elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1860
- self.config.problem_type = "single_label_classification"
1861
- else:
1862
- self.config.problem_type = "multi_label_classification"
1863
-
1864
- if self.config.problem_type == "regression":
1865
- loss_fct = MSELoss()
1866
- if self.config.num_labels == 1:
1867
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1868
- else:
1869
- loss = loss_fct(logits, labels)
1870
- elif self.config.problem_type == "single_label_classification":
1871
- loss_fct = CrossEntropyLoss()
1872
- loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1873
- elif self.config.problem_type == "multi_label_classification":
1874
- loss_fct = BCEWithLogitsLoss()
1875
- loss = loss_fct(logits, labels)
1876
- if not return_dict:
1877
- output = (logits,) + outputs[1:]
1878
- return ((loss,) + output) if loss is not None else output
1879
-
1880
- return Seq2SeqSequenceClassifierOutput(
1881
- loss=loss,
1882
- logits=logits,
1883
- past_key_values=outputs.past_key_values,
1884
- decoder_hidden_states=outputs.decoder_hidden_states,
1885
- decoder_attentions=outputs.decoder_attentions,
1886
- cross_attentions=outputs.cross_attentions,
1887
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1888
- encoder_hidden_states=outputs.encoder_hidden_states,
1889
- encoder_attentions=outputs.encoder_attentions,
1890
- )
1891
-
1892
-
1893
- class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1894
-
1895
- def __init__(self, config):
1896
-
1897
- super().__init__(config)
1898
 
1899
  config.num_labels = 2
1900
  self.num_labels = config.num_labels
@@ -1904,102 +1715,6 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1904
 
1905
  self.model._init_weights(self.qa_outputs)
1906
 
1907
- def forward(
1908
- self,
1909
- input_ids=None,
1910
- attention_mask=None,
1911
- decoder_input_ids=None,
1912
- decoder_attention_mask=None,
1913
- head_mask=None,
1914
- decoder_head_mask=None,
1915
- cross_attn_head_mask=None,
1916
- encoder_outputs=None,
1917
- start_positions=None,
1918
- end_positions=None,
1919
- inputs_embeds=None,
1920
- decoder_inputs_embeds=None,
1921
- use_cache=None,
1922
- output_attentions=None,
1923
- output_hidden_states=None,
1924
- return_dict=None,
1925
- ):
1926
-
1927
- r"""
1928
- start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1929
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1930
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1931
- are not taken into account for computing the loss.
1932
- end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1933
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1934
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1935
- are not taken into account for computing the loss.
1936
- """
1937
-
1938
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1939
- if start_positions is not None and end_positions is not None:
1940
- use_cache = False
1941
-
1942
- outputs = self.model(
1943
- input_ids,
1944
- attention_mask=attention_mask,
1945
- decoder_input_ids=decoder_input_ids,
1946
- decoder_attention_mask=decoder_attention_mask,
1947
- head_mask=head_mask,
1948
- decoder_head_mask=decoder_head_mask,
1949
- cross_attn_head_mask=cross_attn_head_mask,
1950
- encoder_outputs=encoder_outputs,
1951
- inputs_embeds=inputs_embeds,
1952
- decoder_inputs_embeds=decoder_inputs_embeds,
1953
- use_cache=use_cache,
1954
- output_attentions=output_attentions,
1955
- output_hidden_states=output_hidden_states,
1956
- return_dict=return_dict,
1957
- )
1958
-
1959
- sequence_output = outputs[0]
1960
-
1961
- logits = self.qa_outputs(sequence_output)
1962
- start_logits, end_logits = logits.split(1, dim=-1)
1963
- start_logits = start_logits.squeeze(-1).contiguous()
1964
- end_logits = end_logits.squeeze(-1).contiguous()
1965
-
1966
- total_loss = None
1967
- if start_positions is not None and end_positions is not None:
1968
- # If we are on multi-GPU, split add a dimension
1969
- if len(start_positions.size()) > 1:
1970
- start_positions = start_positions.squeeze(-1)
1971
- if len(end_positions.size()) > 1:
1972
- end_positions = end_positions.squeeze(-1)
1973
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1974
- ignored_index = start_logits.size(1)
1975
- start_positions = start_positions.clamp(0, ignored_index)
1976
- end_positions = end_positions.clamp(0, ignored_index)
1977
-
1978
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1979
- start_loss = loss_fct(start_logits, start_positions)
1980
- end_loss = loss_fct(end_logits, end_positions)
1981
- total_loss = (start_loss + end_loss) / 2
1982
-
1983
- if not return_dict:
1984
- output = (
1985
- start_logits,
1986
- end_logits,
1987
- ) + outputs[1:]
1988
- return ((total_loss,) + output) if total_loss is not None else output
1989
-
1990
- return Seq2SeqQuestionAnsweringModelOutput(
1991
- loss=total_loss,
1992
- start_logits=start_logits,
1993
- end_logits=end_logits,
1994
- past_key_values=outputs.past_key_values,
1995
- decoder_hidden_states=outputs.decoder_hidden_states,
1996
- decoder_attentions=outputs.decoder_attentions,
1997
- cross_attentions=outputs.cross_attentions,
1998
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1999
- encoder_hidden_states=outputs.encoder_hidden_states,
2000
- encoder_attentions=outputs.encoder_attentions,
2001
- )
2002
-
2003
 
2004
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2005
  """
@@ -2007,22 +1722,22 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2007
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
2008
  """
2009
 
2010
- def __init__(self, config):
2011
  super().__init__(config)
2012
- self.decoder = BartDecoder(config)
2013
 
2014
  def forward(self, *args, **kwargs):
2015
  return self.decoder(*args, **kwargs)
2016
 
2017
 
2018
- class LSGBartForCausalLM(LSGBartPretrainedModel):
2019
 
2020
- def __init__(self, config):
2021
 
2022
- super().__init__(config)
2023
  config = copy.deepcopy(config)
2024
  config.is_decoder = True
2025
  config.is_encoder_decoder = False
 
2026
  self.model = LSGBartDecoderWrapper(config)
2027
 
2028
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -2030,105 +1745,6 @@ class LSGBartForCausalLM(LSGBartPretrainedModel):
2030
  # Initialize weights and apply final processing
2031
  self.post_init()
2032
 
2033
- def get_input_embeddings(self):
2034
- return self.model.decoder.embed_tokens
2035
-
2036
- def set_input_embeddings(self, value):
2037
- self.model.decoder.embed_tokens = value
2038
-
2039
- def get_output_embeddings(self):
2040
- return self.lm_head
2041
-
2042
- def set_output_embeddings(self, new_embeddings):
2043
- self.lm_head = new_embeddings
2044
-
2045
- def set_decoder(self, decoder):
2046
- self.model.decoder = decoder
2047
-
2048
- def get_decoder(self):
2049
- return self.model.decoder
2050
-
2051
- def forward(
2052
- self,
2053
- input_ids=None,
2054
- attention_mask=None,
2055
- encoder_hidden_states=None,
2056
- encoder_attention_mask=None,
2057
- head_mask=None,
2058
- cross_attn_head_mask=None,
2059
- past_key_values=None,
2060
- inputs_embeds=None,
2061
- labels=None,
2062
- use_cache=None,
2063
- output_attentions=None,
2064
- output_hidden_states=None,
2065
- return_dict=None,
2066
- ):
2067
-
2068
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2069
- output_hidden_states = (
2070
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2071
- )
2072
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2073
-
2074
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2075
- outputs = self.model.decoder(
2076
- input_ids=input_ids,
2077
- attention_mask=attention_mask,
2078
- encoder_hidden_states=encoder_hidden_states,
2079
- encoder_attention_mask=encoder_attention_mask,
2080
- head_mask=head_mask,
2081
- cross_attn_head_mask=cross_attn_head_mask,
2082
- past_key_values=past_key_values,
2083
- inputs_embeds=inputs_embeds,
2084
- use_cache=use_cache,
2085
- output_attentions=output_attentions,
2086
- output_hidden_states=output_hidden_states,
2087
- return_dict=return_dict,
2088
- )
2089
-
2090
- logits = self.lm_head(outputs[0])
2091
-
2092
- loss = None
2093
- if labels is not None:
2094
- loss_fct = CrossEntropyLoss()
2095
- loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2096
-
2097
- if not return_dict:
2098
- output = (logits,) + outputs[1:]
2099
- return (loss,) + output if loss is not None else output
2100
-
2101
- return CausalLMOutputWithCrossAttentions(
2102
- loss=loss,
2103
- logits=logits,
2104
- past_key_values=outputs.past_key_values,
2105
- hidden_states=outputs.hidden_states,
2106
- attentions=outputs.attentions,
2107
- cross_attentions=outputs.cross_attentions,
2108
- )
2109
-
2110
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
2111
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
2112
- if attention_mask is None:
2113
- attention_mask = input_ids.new_ones(input_ids.shape)
2114
-
2115
- if past:
2116
- input_ids = input_ids[:, -1:]
2117
- # first step, decoder_cached_states are empty
2118
- return {
2119
- "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
2120
- "attention_mask": attention_mask,
2121
- "past_key_values": past,
2122
- "use_cache": use_cache,
2123
- }
2124
-
2125
- @staticmethod
2126
- def _reorder_cache(past, beam_idx):
2127
- reordered_past = ()
2128
- for layer_past in past:
2129
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
2130
- return reordered_past
2131
-
2132
 
2133
  def str_to_class(classname):
2134
  return getattr(sys.modules[__name__], classname)
 
41
  ):
42
  """Constructs LSGConfig."""
43
  super().__init__(**kwargs)
44
 
45
  self.adaptive = adaptive
46
  self.auto_map = AUTO_MAP
 
53
  self.sparse_block_size = sparse_block_size
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
+
57
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
58
+ logger.warning(
59
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
60
+ self.sparsity_type = None
61
+
62
+ if self.sparsity_type == "stride":
63
+ if self.sparsity_factor > self.encoder_attention_heads:
64
+ logger.warning(
65
+ "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
66
+ )
67
 
68
+ if self.num_global_tokens < 1:
69
+ logger.warning(
70
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
71
+ )
72
+ self.num_global_tokens = 1
73
+ elif self.num_global_tokens > 512:
74
+ logger.warning(
75
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
76
+ )
77
+ self.num_global_tokens = 512
78
+
79
+ if self.sparsity_factor > 0:
80
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
81
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
82
+
83
 
84
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
85
  """
 
231
 
232
  # Shape of blocks
233
  self.local_shapes = (self.block_size*3, self.block_size)
234
+ if self.sparse_block_size and self.sparsity_factor > 0:
235
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
236
 
237
  self.attention = BaseAttentionProduct(config)
 
330
 
331
  size, step = self.sparse_shapes
332
 
333
+ # In case of odd case
334
+ odd_offset = (step % 2)
335
+
336
  # n, h, t, d*2 + 1
337
  size = size*2
338
+ s = (size - step) // 2 + odd_offset
339
 
340
  # Pad before block reshaping
341
  if is_attn_mask:
 
353
  # Make blocks
354
  hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
355
 
356
+ # Fix case where block_size == sparsity_factor
357
+ if odd_offset:
358
+ hidden_states = hidden_states[..., :-1, :, :]
359
+
360
  # Indexes for selection
361
+ u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
362
  s = self.sparse_block_size
363
 
364
+ u_ = u + odd_offset
365
+ return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
366
 
367
  def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
368
 
 
415
  }
416
 
417
  self.sparsity_type = config.sparsity_type
418
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
419
 
420
  if config.sparsity_type == "lsh":
421
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
422
+
423
  def get_sparse_tokens_with_norm(self, keys, values, mask):
424
 
425
  if self.sparsity_factor == 1:
426
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
427
 
428
  with torch.no_grad():
429
 
 
451
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
452
 
453
  if self.sparsity_factor == 1:
454
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
455
 
456
  keys = self.chunk(keys, self.sparsity_factor)
457
  values = self.chunk(values, self.sparsity_factor)
 
473
  def get_sparse_tokens_with_stride(self, keys, values, mask):
474
 
475
  if self.sparsity_factor == 1:
476
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
477
 
478
  n, h, t, d = keys.size()
479
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
480
  sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
481
  sparse_idx = sparse_idx.expand(n, h, -1, 1)
482
 
483
+ """
484
+ t, b = self.block_size, t // self.block_size
485
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
486
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1, 1)
487
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
488
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
489
+
490
+
491
+ t, b = self.block_size, t // self.block_size
492
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
493
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
494
+ sparse_idx = (sparse_idx % t)
495
+ #sparse_idx[..., -t//2:, :] = (sparse_idx[..., -t//2:, :] + t//2) % t
496
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
497
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
498
+ """
499
+
500
  keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
501
  values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
502
  mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
 
506
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
507
 
508
  if self.sparsity_factor == 1:
509
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
510
 
511
  block_size = min(self.block_size, self.sparse_block_size)
512
  keys = self.chunk(keys, block_size)
 
523
  extra_factor = 1
524
 
525
  for _ in range(self.lsh_num_pre_rounds):
526
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
527
 
528
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
529
  keys /= mask + 1e-8
530
  values /= mask + 1e-8
531
 
 
533
 
534
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
535
 
536
+ def lsh_round(self, keys, values, mask, output_size):
537
 
538
  with torch.no_grad():
539
 
 
1173
 
1174
  # else adaptive sequence length
1175
  elif self.adaptive:
1176
+ # Get last non zero mask index
1177
+ s = int(attention_mask.cumsum(dim=-1).argmax(dim=-1).max()) + 1
1178
  if s < t and self.block_size is not None:
1179
  s = max(2, s // self.block_size + 1) * self.block_size if s > b else s
1180
  if input_ids is not None:
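The replaced line measured the number of attended tokens (`attention_mask.sum`), which only equals the usable length when every zero in the mask is trailing; the new expression returns the index of the last non-zero entry, so masks with internal gaps no longer cause real tokens to be cut. A standalone illustration in plain PyTorch (no LSG code assumed):

```python
import torch

# Mask with an internal gap: 7 attended tokens, the last one sitting at index 7.
mask = torch.tensor([[1, 1, 1, 0, 1, 1, 1, 1, 0, 0]])

old_s = int(torch.max(mask.sum(dim=-1)))                   # 7 -> would drop the token at index 7
new_s = int(mask.cumsum(dim=-1).argmax(dim=-1).max()) + 1  # 8 -> keeps everything up to the last 1
print(old_s, new_s)
```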
 
1337
  self.padding_idx = config.pad_token_id
1338
  self.max_target_positions = config.max_position_embeddings
1339
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1340
+ self.adaptive = config.adaptive
1341
 
1342
  if embed_tokens is not None:
1343
  self.embed_tokens = embed_tokens
 
1380
 
1381
  return combined_attention_mask
1382
 
1383
+ def resize_inputs(self, inputs_embeds, attention_mask):
1384
+ pad = 0
1385
+
1386
+ max_len = int(attention_mask.sum(dim=-1).max())
1387
+ pad = attention_mask.size()[-1] - max_len
1388
+ inputs_embeds = inputs_embeds[:, :max_len]
1389
+ attention_mask = attention_mask[..., :max_len]
1390
+ return pad, inputs_embeds, attention_mask
1391
+
1392
  def forward(
1393
  self,
1394
  input_ids=None,
 
1429
  if inputs_embeds is None:
1430
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1431
 
1432
+ # Resize to reduce computation
1433
+ pad = 0
1434
+ if self.adaptive:
1435
+ if attention_mask is not None:
1436
+ pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
1437
+ input_shape = inputs_embeds.size()[:-1]
1438
+ if encoder_attention_mask is not None:
1439
+ _, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
1440
 
1441
  attention_mask = self._prepare_decoder_attention_mask(
1442
  attention_mask, input_shape, inputs_embeds, past_key_values_length
 
1530
  if encoder_hidden_states is not None:
1531
  all_cross_attentions += (layer_outputs[2],)
1532
 
1533
+ # Resize to original shape
1534
+ hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
1535
+
1536
  # add hidden states from the last decoder layer
1537
  if output_hidden_states:
1538
  all_hidden_states += (hidden_states,)
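The decoder now trims `inputs_embeds` and its attention mask to the longest non-padded sequence in the batch before running the layers, then pads the hidden states back to the original length so downstream shapes are unchanged. A self-contained sketch of that round trip in plain PyTorch (independent of the LSG classes):

```python
import torch
import torch.nn.functional as F

def resize_inputs(inputs_embeds, attention_mask):
    # Keep only the columns up to the longest non-padded sequence in the batch.
    max_len = int(attention_mask.sum(dim=-1).max())
    pad = attention_mask.size(-1) - max_len
    return pad, inputs_embeds[:, :max_len], attention_mask[..., :max_len]

embeds = torch.randn(2, 10, 4)
mask = torch.tensor([[1] * 6 + [0] * 4,
                     [1] * 4 + [0] * 6])

pad, short_embeds, short_mask = resize_inputs(embeds, mask)
hidden = short_embeds * 2  # stand-in for the decoder layers

# Pad the time dimension back so callers see the original sequence length.
restored = F.pad(hidden.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
assert restored.shape == embeds.shape
```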
 
1669
  )
1670
 
1671
 
1672
+ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
1673
 
1674
  base_model_prefix = "model"
1675
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1676
 
1677
  def __init__(self, config):
1678
 
1679
+ LSGBartPretrainedModel.__init__(self, config)
1680
  self.model = LSGBartModel(config)
1681
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1682
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
1684
  # Initialize weights and apply final processing
1685
  self.post_init()
1686

1687
 
1688
+ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
1689
 
1690
+ def __init__(self, config: LSGBartConfig, **kwargs):
1691
 
1692
+ LSGBartPretrainedModel.__init__(self, config, **kwargs)
1693
  self.model = LSGBartModel(config)
1694
  self.classification_head = LSGBartClassificationHead(
1695
  config.d_model,
 
1700
  self.model._init_weights(self.classification_head.dense)
1701
  self.model._init_weights(self.classification_head.out_proj)
1702

1703
 
1704
+ class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
1705
 
1706
+ def __init__(self, config: LSGBartConfig):
1707
 
1708
+ LSGBartPretrainedModel.__init__(self, config)
1709
 
1710
  config.num_labels = 2
1711
  self.num_labels = config.num_labels
 
1715
 
1716
  self.model._init_weights(self.qa_outputs)
1717

1718
 
1719
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
1720
  """
 
1722
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
1723
  """
1724
 
1725
+ def __init__(self, config: LSGBartConfig):
1726
  super().__init__(config)
1727
+ self.decoder = LSGBartDecoder(config)
1728
 
1729
  def forward(self, *args, **kwargs):
1730
  return self.decoder(*args, **kwargs)
1731
 
1732
 
1733
+ class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
1734
 
1735
+ def __init__(self, config: LSGBartConfig):
1736

1737
  config = copy.deepcopy(config)
1738
  config.is_decoder = True
1739
  config.is_encoder_decoder = False
1740
+ LSGBartPretrainedModel.__init__(self, config)
1741
  self.model = LSGBartDecoderWrapper(config)
1742
 
1743
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
1745
  # Initialize weights and apply final processing
1746
  self.post_init()
1747

1748
 
1749
  def str_to_class(classname):
1750
  return getattr(sys.modules[__name__], classname)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0baa2aecffe0dc5c00cb4f23b89134663008055a409645e422850c2e5d78240f
3
  size 578416695
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:933a1e3345672ba1ca8fb2956ca511a720e4a4ae54fe466c80c12c4a30df281b
3
  size 578416695