ccdv committed on
Commit
a917f9e
1 Parent(s): 1389bb8
README.md CHANGED
@@ -21,19 +21,26 @@ should probably proofread and complete it, then remove this comment. -->
21
  # ccdv/lsg-bart-base-16384-arxiv
22
 
23
  This model is a fine-tuned version of [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv) on the scientific_papers arxiv dataset. \
 
24
  It achieves the following results on the test set:
25
 
26
- | Length | Global tokens | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
27
  |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
28
- | 16384 | 64 | - | 256 | 0 | 768 | 48.55 | 20.76 | 28.39 | 44.03 |
 
 
29
 
 
 
 
 
30
 
31
  ## Model description
32
  The model relies on Local-Sparse-Global attention to handle long sequences:
33
  ![attn](attn.png)
34
 
35
  The model has about 145 million parameters (6 encoder layers, 6 decoder layers). \
36
- The model is warm-started from [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv), converted to handle long sequences (encoder only) and fine-tuned. \
37
 
38
  ## Intended uses & limitations
39
 
@@ -49,12 +56,13 @@ More information needed
49
 
50
  The following hyperparameters were used during training:
51
  - learning_rate: 8e-05
52
- - train_batch_size: 1
53
  - seed: 42
54
- - gradient_accumulation_steps: 32
55
  - total_train_batch_size: 32
56
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
57
  - lr_scheduler_type: linear
 
58
  - num_epochs: 1.0
59
 
60
  ### Generate hyperparameters
@@ -62,14 +70,14 @@ The following hyperparameters were used during training:
62
  The following hyperparameters were used during generation:
63
  - dataset_name: scientific_papers
64
  - dataset_config_name: arxiv
65
- - eval_batch_size: 2
 
66
  - early_stopping: True
67
  - ignore_pad_token_for_loss: True
68
  - length_penalty: 2.0
69
  - max_length: 320
70
- - min_length: 64
71
  - num_beams: 5
72
- - num_samples: None
73
  - no_repeat_ngram_size: None
74
  - seed: 123
75
 
 
21
  # ccdv/lsg-bart-base-16384-arxiv
22
 
23
  This model is a fine-tuned version of [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv) on the scientific_papers arxiv dataset. \
24
+ The model is converted to handle sequences of up to 16384 tokens and fine-tuned for 1 epoch. \
25
  It achieves the following results on the test set:
26
 
27
+ | Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
28
  |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
29
+ | 16384 | 64 | Full | 256 | 0 | 768 | 48.74 | 20.88 | 28.50 | 44.23 |
30
+ | 16384 | 64 | Global only | 256 | 0 | 768 | 48.08 | 20.42 | 28.00 | 43.65 |
31
+ | 16384 | 1 | None | 256 | 0 | 768 | 47.03 | 20.19 | 28.26 | 42.69 |
32
 
33
+ Reference model:
34
+ | Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
35
+ |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
36
+ | 4096 | 1 | - | 256 | 0 | 768 | 46.65 | 18.91 | 26.90 | 42.18 |
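To make the comparison between the two tables concrete, the short sketch below computes the per-metric gain of the fully fine-tuned 16384 model over the 4096 reference model. The scores are copied from the rows above; the dictionary and function names are illustrative only and do not come from the repository.

```python
# ROUGE scores copied from the README tables above.
full_16384 = {"R1": 48.74, "R2": 20.88, "RL": 28.50, "RLsum": 44.23}
reference_4096 = {"R1": 46.65, "R2": 18.91, "RL": 26.90, "RLsum": 42.18}


def rouge_gains(candidate, baseline):
    """Return candidate minus baseline for each metric, rounded to 2 decimals."""
    return {name: round(candidate[name] - baseline[name], 2) for name in baseline}


gains = rouge_gains(full_16384, reference_4096)
for name, delta in gains.items():
    # Print a signed difference in ROUGE points for each metric.
    print(f"{name}: {delta:+.2f}")
```

The differences are plain subtractions of the reported test-set scores; they say nothing about statistical significance.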
37
 
38
  ## Model description
39
  The model relies on Local-Sparse-Global attention to handle long sequences:
40
  ![attn](attn.png)
41
 
42
  The model has about 145 million parameters (6 encoder layers, 6 decoder layers). \
43
+ The model is warm-started from [ccdv/lsg-bart-base-4096-arxiv](https://huggingface.co/ccdv/lsg-bart-base-4096-arxiv), converted to handle long sequences (encoder only) and fine-tuned.
44
 
45
  ## Intended uses & limitations
46
 
 
56
 
57
  The following hyperparameters were used during training:
58
  - learning_rate: 8e-05
59
+ - train_batch_size: 8
60
  - seed: 42
61
+ - gradient_accumulation_steps: 4
62
  - total_train_batch_size: 32
63
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
64
  - lr_scheduler_type: linear
65
+ - lr_scheduler_warmup_ratio: 0.1
66
  - num_epochs: 1.0
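The hyperparameters above can be related to one another with a little arithmetic. The sketch below is a hedged illustration, not the training script: it assumes a single device (consistent with 8 × 4 = 32) and a linear schedule that ramps up over the warmup fraction and then decays to zero. The function names, the `num_devices` parameter, and the example step count are all hypothetical.

```python
import math


def effective_batch_size(per_device_batch, accumulation_steps, num_devices=1):
    """Number of samples contributing to each optimizer update."""
    return per_device_batch * accumulation_steps * num_devices


def linear_warmup_lr(step, total_steps, peak_lr=8e-05, warmup_ratio=0.1):
    """Linear warmup to peak_lr, then linear decay to 0 (illustrative sketch)."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp up from 0 to peak_lr over the warmup steps.
        return peak_lr * step / max(1, warmup_steps)
    # Decay linearly from peak_lr to 0 over the remaining steps.
    remaining = max(0.0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


# train_batch_size=8 and gradient_accumulation_steps=4, as listed above.
print(effective_batch_size(8, 4))

# Hypothetical run length, chosen only to show the shape of the schedule.
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_lr(step, total_steps=1000))
```

The previous configuration (batch size 1 with 32 accumulation steps) gives the same effective batch size, which is why `total_train_batch_size` is unchanged in this commit.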
67
 
68
  ### Generate hyperparameters
 
70
  The following hyperparameters were used during generation:
71
  - dataset_name: scientific_papers
72
  - dataset_config_name: arxiv
73
+ - eval_batch_size: 8
74
+ - eval_samples: 6440
75
  - early_stopping: True
76
  - ignore_pad_token_for_loss: True
77
  - length_penalty: 2.0
78
  - max_length: 320
79
+ - min_length: 32
80
  - num_beams: 5
 
81
  - no_repeat_ngram_size: None
82
  - seed: 123
83
 
all_results.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "eval_gen_len": 215.796,
3
+ "eval_loss": 1.7052853107452393,
4
+ "eval_rouge1": 48.7438,
5
+ "eval_rouge2": 20.88,
6
+ "eval_rougeL": 28.4965,
7
+ "eval_rougeLsum": 44.2266,
8
+ "eval_runtime": 18597.9286,
9
+ "eval_samples": 6440,
10
+ "eval_samples_per_second": 0.346,
11
+ "eval_steps_per_second": 0.087
12
+ }
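As a quick sanity check on the file above, the reported throughput can be recomputed from the sample count and the runtime. This is a minimal sketch that embeds the relevant values inline rather than reading `all_results.json` from disk; the variable names are illustrative.

```python
import json

# Values copied from all_results.json above.
results = json.loads("""
{
  "eval_runtime": 18597.9286,
  "eval_samples": 6440,
  "eval_samples_per_second": 0.346
}
""")

# Throughput is the number of evaluated samples divided by wall-clock seconds.
samples_per_second = results["eval_samples"] / results["eval_runtime"]
runtime_hours = results["eval_runtime"] / 3600

print(f"samples/s: {samples_per_second:.3f}")
print(f"runtime:   {runtime_hours:.2f} h")
```

The recomputed value agrees with the reported `eval_samples_per_second` to three decimals.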
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "ccdv/lsg-bart-base-16384-arxiv",
3
  "activation_dropout": 0.1,
4
  "activation_function": "gelu",
5
  "adaptive": true,
@@ -68,7 +68,7 @@
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
  "sparsity_factor": 4,
71
- "sparsity_type": "norm",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
@@ -90,7 +90,7 @@
90
  }
91
  },
92
  "torch_dtype": "float32",
93
- "transformers_version": "4.18.0",
94
  "use_cache": true,
95
  "vocab_size": 50265
96
  }
 
1
  {
2
+ "_name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/arxiv/lsg_local_16384_trained",
3
  "activation_dropout": 0.1,
4
  "activation_function": "gelu",
5
  "adaptive": true,
 
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
  "sparsity_factor": 4,
71
+ "sparsity_type": "none",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
 
90
  }
91
  },
92
  "torch_dtype": "float32",
93
+ "transformers_version": "4.19.2",
94
  "use_cache": true,
95
  "vocab_size": 50265
96
  }
eval_results.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
- "eval_gen_len": 218.5635,
3
- "eval_loss": 1.7079483270645142,
4
- "eval_rouge1": 48.5477,
5
- "eval_rouge2": 20.7579,
6
- "eval_rougeL": 28.3878,
7
- "eval_rougeLsum": 44.0302,
8
- "eval_runtime": 21593.137,
9
  "eval_samples": 6440,
10
- "eval_samples_per_second": 0.298,
11
- "eval_steps_per_second": 0.149
12
  }
 
1
  {
2
+ "eval_gen_len": 215.796,
3
+ "eval_loss": 1.7052853107452393,
4
+ "eval_rouge1": 48.7438,
5
+ "eval_rouge2": 20.88,
6
+ "eval_rougeL": 28.4965,
7
+ "eval_rougeLsum": 44.2266,
8
+ "eval_runtime": 18597.9286,
9
  "eval_samples": 6440,
10
+ "eval_samples_per_second": 0.346,
11
+ "eval_steps_per_second": 0.087
12
  }
modeling_lsg_bart.py CHANGED
@@ -54,17 +54,32 @@ class LSGBartConfig(BartConfig):
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
 
57
- if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
58
  logger.warning(
59
- "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
60
  self.sparsity_type = None
61
 
62
- if self.sparsity_type == "stride":
63
  if self.sparsity_factor > self.encoder_attention_heads:
64
  logger.warning(
65
- "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
66
  )
67

68
 
69
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
70
  """
@@ -217,8 +232,6 @@ class LSGAttentionProduct(nn.Module):
217
  # Shape of blocks
218
  self.local_shapes = (self.block_size*3, self.block_size)
219
  if self.sparse_block_size and self.sparsity_factor > 0:
220
- assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
221
- assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
222
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
223
 
224
  self.attention = BaseAttentionProduct(config)
@@ -399,6 +412,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
399
  "pooling": self.get_sparse_tokens_with_pooling,
400
  "lsh": self.get_sparse_tokens_with_lsh,
401
  "stride": self.get_sparse_tokens_with_stride,
 
402
  }
403
 
404
  self.sparsity_type = config.sparsity_type
@@ -410,7 +424,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
410
  def get_sparse_tokens_with_norm(self, keys, values, mask):
411
 
412
  if self.sparsity_factor == 1:
413
- return keys, values, mask
414
 
415
  with torch.no_grad():
416
 
@@ -438,7 +452,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
438
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
439
 
440
  if self.sparsity_factor == 1:
441
- return keys, values, mask
442
 
443
  keys = self.chunk(keys, self.sparsity_factor)
444
  values = self.chunk(values, self.sparsity_factor)
@@ -460,7 +474,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
460
  def get_sparse_tokens_with_stride(self, keys, values, mask):
461
 
462
  if self.sparsity_factor == 1:
463
- return keys, values, mask
464
 
465
  n, h, t, d = keys.size()
466
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
@@ -473,10 +487,30 @@ class LSGBartEncoderAttention(BaseSelfAttention):
473
 
474
  return keys, values, mask
475

476
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
477
 
478
  if self.sparsity_factor == 1:
479
- return keys, values, mask
480
 
481
  block_size = min(self.block_size, self.sparse_block_size)
482
  keys = self.chunk(keys, block_size)
@@ -1307,6 +1341,7 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1307
  self.padding_idx = config.pad_token_id
1308
  self.max_target_positions = config.max_position_embeddings
1309
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
1310
 
1311
  if embed_tokens is not None:
1312
  self.embed_tokens = embed_tokens
@@ -1349,6 +1384,15 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1349
 
1350
  return combined_attention_mask
1351

1352
  def forward(
1353
  self,
1354
  input_ids=None,
@@ -1389,12 +1433,14 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1389
  if inputs_embeds is None:
1390
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1391
 
1392
- # Cut
1393
- if attention_mask is not None:
1394
- max_len = int(attention_mask.sum(dim=-1).max())
1395
- inputs_embeds = inputs_embeds[:, :max_len]
1396
- attention_mask = attention_mask[..., :max_len]
1397
- input_shape = inputs_embeds.size()[:-1]
 
 
1398
 
1399
  attention_mask = self._prepare_decoder_attention_mask(
1400
  attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1488,6 +1534,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1488
  if encoder_hidden_states is not None:
1489
  all_cross_attentions += (layer_outputs[2],)
1490
 
 
 
 
1491
  # add hidden states from the last decoder layer
1492
  if output_hidden_states:
1493
  all_hidden_states += (hidden_states,)
@@ -1624,14 +1673,14 @@ class LSGBartModel(LSGBartPretrainedModel):
1624
  )
1625
 
1626
 
1627
- class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1628
 
1629
  base_model_prefix = "model"
1630
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1631
 
1632
  def __init__(self, config):
1633
 
1634
- super().__init__(config)
1635
  self.model = LSGBartModel(config)
1636
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1637
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
@@ -1639,157 +1688,12 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1639
  # Initialize weights and apply final processing
1640
  self.post_init()
1641
 
1642
- def get_encoder(self):
1643
- return self.model.get_encoder()
1644
-
1645
- def get_decoder(self):
1646
- return self.model.get_decoder()
1647
-
1648
- def resize_token_embeddings(self, new_num_tokens):
1649
- new_embeddings = super().resize_token_embeddings(new_num_tokens)
1650
- self._resize_final_logits_bias(new_num_tokens)
1651
- return new_embeddings
1652
-
1653
- def _resize_final_logits_bias(self, new_num_tokens):
1654
- old_num_tokens = self.final_logits_bias.shape[-1]
1655
- if new_num_tokens <= old_num_tokens:
1656
- new_bias = self.final_logits_bias[:, :new_num_tokens]
1657
- else:
1658
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1659
- new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1660
- self.register_buffer("final_logits_bias", new_bias)
1661
-
1662
- def get_output_embeddings(self):
1663
- return self.lm_head
1664
-
1665
- def set_output_embeddings(self, new_embeddings):
1666
- self.lm_head = new_embeddings
1667
-
1668
- def forward(
1669
- self,
1670
- input_ids=None,
1671
- attention_mask=None,
1672
- decoder_input_ids=None,
1673
- decoder_attention_mask=None,
1674
- head_mask=None,
1675
- decoder_head_mask=None,
1676
- cross_attn_head_mask=None,
1677
- encoder_outputs=None,
1678
- past_key_values=None,
1679
- inputs_embeds=None,
1680
- decoder_inputs_embeds=None,
1681
- labels=None,
1682
- use_cache=None,
1683
- output_attentions=None,
1684
- output_hidden_states=None,
1685
- return_dict=None,
1686
- ):
1687
-
1688
- r"""
1689
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1690
- Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
1691
- config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1692
- (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1693
- Returns:
1694
- """
1695
 
1696
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1697
 
1698
- if labels is not None:
1699
- if decoder_input_ids is None and decoder_inputs_embeds is None:
1700
- decoder_input_ids = shift_tokens_right(
1701
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
1702
- )
1703
 
1704
- outputs = self.model(
1705
- input_ids,
1706
- attention_mask=attention_mask,
1707
- decoder_input_ids=decoder_input_ids,
1708
- encoder_outputs=encoder_outputs,
1709
- decoder_attention_mask=decoder_attention_mask,
1710
- head_mask=head_mask,
1711
- decoder_head_mask=decoder_head_mask,
1712
- cross_attn_head_mask=cross_attn_head_mask,
1713
- past_key_values=past_key_values,
1714
- inputs_embeds=inputs_embeds,
1715
- decoder_inputs_embeds=decoder_inputs_embeds,
1716
- use_cache=use_cache,
1717
- output_attentions=output_attentions,
1718
- output_hidden_states=output_hidden_states,
1719
- return_dict=return_dict,
1720
- )
1721
-
1722
-
1723
- lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1724
-
1725
- masked_lm_loss = None
1726
- if labels is not None:
1727
- loss_fct = CrossEntropyLoss()
1728
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1729
-
1730
- if not return_dict:
1731
- output = (lm_logits,) + outputs[1:]
1732
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1733
-
1734
- return Seq2SeqLMOutput(
1735
- loss=masked_lm_loss,
1736
- logits=lm_logits,
1737
- past_key_values=outputs.past_key_values,
1738
- decoder_hidden_states=outputs.decoder_hidden_states,
1739
- decoder_attentions=outputs.decoder_attentions,
1740
- cross_attentions=outputs.cross_attentions,
1741
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1742
- encoder_hidden_states=outputs.encoder_hidden_states,
1743
- encoder_attentions=outputs.encoder_attentions,
1744
- )
1745
-
1746
- def prepare_inputs_for_generation(
1747
- self,
1748
- decoder_input_ids,
1749
- past=None,
1750
- attention_mask=None,
1751
- head_mask=None,
1752
- decoder_head_mask=None,
1753
- cross_attn_head_mask=None,
1754
- use_cache=None,
1755
- encoder_outputs=None,
1756
- **kwargs
1757
- ):
1758
- # cut decoder_input_ids if past is used
1759
- if past is not None:
1760
- decoder_input_ids = decoder_input_ids[:, -1:]
1761
-
1762
- return {
1763
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
1764
- "encoder_outputs": encoder_outputs,
1765
- "past_key_values": past,
1766
- "decoder_input_ids": decoder_input_ids,
1767
- "attention_mask": attention_mask,
1768
- "head_mask": head_mask,
1769
- "decoder_head_mask": decoder_head_mask,
1770
- "cross_attn_head_mask": cross_attn_head_mask,
1771
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1772
- }
1773
-
1774
- def prepare_decoder_input_ids_from_labels(self, labels):
1775
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1776
-
1777
- @staticmethod
1778
- def _reorder_cache(past, beam_idx):
1779
- reordered_past = ()
1780
- for layer_past in past:
1781
- # cached cross_attention states don't have to be reordered -> they are always the same
1782
- reordered_past += (
1783
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1784
- )
1785
- return reordered_past
1786
-
1787
-
1788
- class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1789
-
1790
- def __init__(self, config, **kwargs):
1791
-
1792
- super().__init__(config, **kwargs)
1793
  self.model = LSGBartModel(config)
1794
  self.classification_head = LSGBartClassificationHead(
1795
  config.d_model,
@@ -1800,115 +1704,12 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1800
  self.model._init_weights(self.classification_head.dense)
1801
  self.model._init_weights(self.classification_head.out_proj)
1802
 
1803
- def forward(
1804
- self,
1805
- input_ids=None,
1806
- attention_mask=None,
1807
- decoder_input_ids=None,
1808
- decoder_attention_mask=None,
1809
- head_mask=None,
1810
- decoder_head_mask=None,
1811
- cross_attn_head_mask=None,
1812
- encoder_outputs=None,
1813
- inputs_embeds=None,
1814
- decoder_inputs_embeds=None,
1815
- labels=None,
1816
- use_cache=None,
1817
- output_attentions=None,
1818
- output_hidden_states=None,
1819
- return_dict=None,
1820
- ):
1821
-
1822
- r"""
1823
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1824
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1825
- config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1826
- """
1827
-
1828
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1829
- if labels is not None:
1830
- use_cache = False
1831
-
1832
- if input_ids is None and inputs_embeds is not None:
1833
- raise NotImplementedError(
1834
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1835
- )
1836
-
1837
- outputs = self.model(
1838
- input_ids,
1839
- attention_mask=attention_mask,
1840
- decoder_input_ids=decoder_input_ids,
1841
- decoder_attention_mask=decoder_attention_mask,
1842
- head_mask=head_mask,
1843
- decoder_head_mask=decoder_head_mask,
1844
- cross_attn_head_mask=cross_attn_head_mask,
1845
- encoder_outputs=encoder_outputs,
1846
- inputs_embeds=inputs_embeds,
1847
- decoder_inputs_embeds=decoder_inputs_embeds,
1848
- use_cache=use_cache,
1849
- output_attentions=output_attentions,
1850
- output_hidden_states=output_hidden_states,
1851
- return_dict=return_dict,
1852
- )
1853
- hidden_states = outputs[0] # last hidden state
1854
-
1855
- eos_mask = input_ids.eq(self.config.eos_token_id)
1856
-
1857
- t, t_ = eos_mask.size()[-1], hidden_states.size()[-2]
1858
- if t > t_:
1859
- eos_mask = eos_mask[:, :t_]
1860
-
1861
- if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1862
- raise ValueError("All examples must have the same number of <eos> tokens.")
1863
- sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1864
- :, -1, :
1865
- ]
1866
- logits = self.classification_head(sentence_representation)
1867
-
1868
- loss = None
1869
- if labels is not None:
1870
- if self.config.problem_type is None:
1871
- if self.config.num_labels == 1:
1872
- self.config.problem_type = "regression"
1873
- elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1874
- self.config.problem_type = "single_label_classification"
1875
- else:
1876
- self.config.problem_type = "multi_label_classification"
1877
-
1878
- if self.config.problem_type == "regression":
1879
- loss_fct = MSELoss()
1880
- if self.config.num_labels == 1:
1881
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1882
- else:
1883
- loss = loss_fct(logits, labels)
1884
- elif self.config.problem_type == "single_label_classification":
1885
- loss_fct = CrossEntropyLoss()
1886
- loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1887
- elif self.config.problem_type == "multi_label_classification":
1888
- loss_fct = BCEWithLogitsLoss()
1889
- loss = loss_fct(logits, labels)
1890
- if not return_dict:
1891
- output = (logits,) + outputs[1:]
1892
- return ((loss,) + output) if loss is not None else output
1893
-
1894
- return Seq2SeqSequenceClassifierOutput(
1895
- loss=loss,
1896
- logits=logits,
1897
- past_key_values=outputs.past_key_values,
1898
- decoder_hidden_states=outputs.decoder_hidden_states,
1899
- decoder_attentions=outputs.decoder_attentions,
1900
- cross_attentions=outputs.cross_attentions,
1901
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1902
- encoder_hidden_states=outputs.encoder_hidden_states,
1903
- encoder_attentions=outputs.encoder_attentions,
1904
- )
1905
 
 
1906
 
1907
- class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1908
 
1909
- def __init__(self, config):
1910
-
1911
- super().__init__(config)
1912
 
1913
  config.num_labels = 2
1914
  self.num_labels = config.num_labels
@@ -1918,102 +1719,6 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1918
 
1919
  self.model._init_weights(self.qa_outputs)
1920
 
1921
- def forward(
1922
- self,
1923
- input_ids=None,
1924
- attention_mask=None,
1925
- decoder_input_ids=None,
1926
- decoder_attention_mask=None,
1927
- head_mask=None,
1928
- decoder_head_mask=None,
1929
- cross_attn_head_mask=None,
1930
- encoder_outputs=None,
1931
- start_positions=None,
1932
- end_positions=None,
1933
- inputs_embeds=None,
1934
- decoder_inputs_embeds=None,
1935
- use_cache=None,
1936
- output_attentions=None,
1937
- output_hidden_states=None,
1938
- return_dict=None,
1939
- ):
1940
-
1941
- r"""
1942
- start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1943
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1944
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1945
- are not taken into account for computing the loss.
1946
- end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1947
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1948
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1949
- are not taken into account for computing the loss.
1950
- """
1951
-
1952
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1953
- if start_positions is not None and end_positions is not None:
1954
- use_cache = False
1955
-
1956
- outputs = self.model(
1957
- input_ids,
1958
- attention_mask=attention_mask,
1959
- decoder_input_ids=decoder_input_ids,
1960
- decoder_attention_mask=decoder_attention_mask,
1961
- head_mask=head_mask,
1962
- decoder_head_mask=decoder_head_mask,
1963
- cross_attn_head_mask=cross_attn_head_mask,
1964
- encoder_outputs=encoder_outputs,
1965
- inputs_embeds=inputs_embeds,
1966
- decoder_inputs_embeds=decoder_inputs_embeds,
1967
- use_cache=use_cache,
1968
- output_attentions=output_attentions,
1969
- output_hidden_states=output_hidden_states,
1970
- return_dict=return_dict,
1971
- )
1972
-
1973
- sequence_output = outputs[0]
1974
-
1975
- logits = self.qa_outputs(sequence_output)
1976
- start_logits, end_logits = logits.split(1, dim=-1)
1977
- start_logits = start_logits.squeeze(-1).contiguous()
1978
- end_logits = end_logits.squeeze(-1).contiguous()
1979
-
1980
- total_loss = None
1981
- if start_positions is not None and end_positions is not None:
1982
- # If we are on multi-GPU, split add a dimension
1983
- if len(start_positions.size()) > 1:
1984
- start_positions = start_positions.squeeze(-1)
1985
- if len(end_positions.size()) > 1:
1986
- end_positions = end_positions.squeeze(-1)
1987
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1988
- ignored_index = start_logits.size(1)
1989
- start_positions = start_positions.clamp(0, ignored_index)
1990
- end_positions = end_positions.clamp(0, ignored_index)
1991
-
1992
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1993
- start_loss = loss_fct(start_logits, start_positions)
1994
- end_loss = loss_fct(end_logits, end_positions)
1995
- total_loss = (start_loss + end_loss) / 2
1996
-
1997
- if not return_dict:
1998
- output = (
1999
- start_logits,
2000
- end_logits,
2001
- ) + outputs[1:]
2002
- return ((total_loss,) + output) if total_loss is not None else output
2003
-
2004
- return Seq2SeqQuestionAnsweringModelOutput(
2005
- loss=total_loss,
2006
- start_logits=start_logits,
2007
- end_logits=end_logits,
2008
- past_key_values=outputs.past_key_values,
2009
- decoder_hidden_states=outputs.decoder_hidden_states,
2010
- decoder_attentions=outputs.decoder_attentions,
2011
- cross_attentions=outputs.cross_attentions,
2012
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
2013
- encoder_hidden_states=outputs.encoder_hidden_states,
2014
- encoder_attentions=outputs.encoder_attentions,
2015
- )
2016
-
2017
 
2018
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2019
  """
@@ -2021,7 +1726,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2021
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
2022
  """
2023
 
2024
- def __init__(self, config):
2025
  super().__init__(config)
2026
  self.decoder = LSGBartDecoder(config)
2027
 
@@ -2029,14 +1734,14 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2029
  return self.decoder(*args, **kwargs)
2030
 
2031
 
2032
- class LSGBartForCausalLM(LSGBartPretrainedModel):
2033
 
2034
- def __init__(self, config):
2035
 
2036
- super().__init__(config)
2037
  config = copy.deepcopy(config)
2038
  config.is_decoder = True
2039
  config.is_encoder_decoder = False
 
2040
  self.model = LSGBartDecoderWrapper(config)
2041
 
2042
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -2044,105 +1749,6 @@ class LSGBartForCausalLM(LSGBartPretrainedModel):
2044
  # Initialize weights and apply final processing
2045
  self.post_init()
2046
 
2047
- def get_input_embeddings(self):
2048
- return self.model.decoder.embed_tokens
2049
-
2050
- def set_input_embeddings(self, value):
2051
- self.model.decoder.embed_tokens = value
2052
-
2053
- def get_output_embeddings(self):
2054
- return self.lm_head
2055
-
2056
- def set_output_embeddings(self, new_embeddings):
2057
- self.lm_head = new_embeddings
2058
-
2059
- def set_decoder(self, decoder):
2060
- self.model.decoder = decoder
2061
-
2062
- def get_decoder(self):
2063
- return self.model.decoder
2064
-
2065
- def forward(
2066
- self,
2067
- input_ids=None,
2068
- attention_mask=None,
2069
- encoder_hidden_states=None,
2070
- encoder_attention_mask=None,
2071
- head_mask=None,
2072
- cross_attn_head_mask=None,
2073
- past_key_values=None,
2074
- inputs_embeds=None,
2075
- labels=None,
2076
- use_cache=None,
2077
- output_attentions=None,
2078
- output_hidden_states=None,
2079
- return_dict=None,
2080
- ):
2081
-
2082
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2083
- output_hidden_states = (
2084
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2085
- )
2086
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2087
-
2088
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2089
- outputs = self.model.decoder(
2090
- input_ids=input_ids,
2091
- attention_mask=attention_mask,
2092
- encoder_hidden_states=encoder_hidden_states,
2093
- encoder_attention_mask=encoder_attention_mask,
2094
- head_mask=head_mask,
2095
- cross_attn_head_mask=cross_attn_head_mask,
2096
- past_key_values=past_key_values,
2097
- inputs_embeds=inputs_embeds,
2098
- use_cache=use_cache,
2099
- output_attentions=output_attentions,
2100
- output_hidden_states=output_hidden_states,
2101
- return_dict=return_dict,
2102
- )
2103
-
2104
- logits = self.lm_head(outputs[0])
2105
-
2106
- loss = None
2107
- if labels is not None:
2108
- loss_fct = CrossEntropyLoss()
2109
- loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2110
-
2111
- if not return_dict:
2112
- output = (logits,) + outputs[1:]
2113
- return (loss,) + output if loss is not None else output
2114
-
2115
- return CausalLMOutputWithCrossAttentions(
2116
- loss=loss,
2117
- logits=logits,
2118
- past_key_values=outputs.past_key_values,
2119
- hidden_states=outputs.hidden_states,
2120
- attentions=outputs.attentions,
2121
- cross_attentions=outputs.cross_attentions,
2122
- )
2123
-
2124
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
2125
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
2126
- if attention_mask is None:
2127
- attention_mask = input_ids.new_ones(input_ids.shape)
2128
-
2129
- if past:
2130
- input_ids = input_ids[:, -1:]
2131
- # first step, decoder_cached_states are empty
2132
- return {
2133
- "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
2134
- "attention_mask": attention_mask,
2135
- "past_key_values": past,
2136
- "use_cache": use_cache,
2137
- }
2138
-
2139
- @staticmethod
2140
- def _reorder_cache(past, beam_idx):
2141
- reordered_past = ()
2142
- for layer_past in past:
2143
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
2144
- return reordered_past
2145
-
2146
 
2147
  def str_to_class(classname):
2148
  return getattr(sys.modules[__name__], classname)
 
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
+        if self.sparsity_type in ["stride", "block_stride"]:
             if self.sparsity_factor > self.encoder_attention_heads:
                 logger.warning(
+                    "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
+        if self.num_global_tokens < 1:
+            logger.warning(
+                "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
+            )
+            self.num_global_tokens = 1
+        elif self.num_global_tokens > 512:
+            logger.warning(
+                "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+            )
+            self.num_global_tokens = 512
+
+        if self.sparsity_factor > 0:
+            assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
+            assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
+
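The checks added in this hunk can be read as a small validation routine. The following is a minimal standalone sketch of that logic, not the repository's code: `validate_sparse_config` and `VALID_SPARSITY_TYPES` are hypothetical names, and it raises `ValueError` where the diff uses `assert`.

```python
import logging

logger = logging.getLogger(__name__)

# Sparsity types accepted by the config after this commit (taken from the diff).
VALID_SPARSITY_TYPES = [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]


def validate_sparse_config(sparsity_type, sparsity_factor, block_size, num_global_tokens):
    """Hypothetical helper mirroring the config checks shown in the hunk."""
    if sparsity_type not in VALID_SPARSITY_TYPES:
        logger.warning("unknown sparsity_type %r, falling back to None", sparsity_type)
        sparsity_type = None

    # The diff clamps num_global_tokens to the range [1, 512].
    num_global_tokens = max(1, min(512, num_global_tokens))

    # Assumption: an exception is used here instead of the assert in the diff.
    if sparsity_factor > 0 and block_size % sparsity_factor != 0:
        raise ValueError("block_size must be divisible by sparsity_factor")

    return sparsity_type, num_global_tokens
```

For example, `validate_sparse_config("foo", 4, 256, 0)` falls back to no sparse attention and a single global token.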
 
 def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
     """
 
         # Shape of blocks
         self.local_shapes = (self.block_size*3, self.block_size)
         if self.sparse_block_size and self.sparsity_factor > 0:
             self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
 
         self.attention = BaseAttentionProduct(config)
 
             "pooling": self.get_sparse_tokens_with_pooling,
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
+            "block_stride": self.get_sparse_tokens_with_block_stride,
         }
 
         self.sparsity_type = config.sparsity_type
 
     def get_sparse_tokens_with_norm(self, keys, values, mask):
 
         if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
         with torch.no_grad():
 
 
     def get_sparse_tokens_with_pooling(self, keys, values, mask):
 
         if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
         keys = self.chunk(keys, self.sparsity_factor)
         values = self.chunk(values, self.sparsity_factor)
 
     def get_sparse_tokens_with_stride(self, keys, values, mask):
 
         if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
         n, h, t, d = keys.size()
         sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
 
 
         return keys, values, mask
 
+    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        n, h, t, d = keys.size()
+
+        t, b = self.block_size, t // self.block_size
+        sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
+        sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
+        sparse_idx = (sparse_idx % t)
+        sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
+        sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
+
+        keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
+        values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
+        mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
+
+        return keys, values, mask
+
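To make the index arithmetic in `get_sparse_tokens_with_block_stride` easier to follow, here is a pure-Python sketch of which key positions each head selects. It reproduces only the index computation from the diff (no tensors, no gathering); `block_stride_indices` is a hypothetical name, and the sketch assumes the sequence length is a multiple of `block_size`.

```python
def block_stride_indices(seq_len, block_size, sparsity_factor, num_heads):
    """Return, for each head, the key positions kept by block-stride selection."""
    per_block = block_size // sparsity_factor  # tokens kept inside every block
    num_blocks = seq_len // block_size
    indices = []
    for head in range(num_heads):
        # Each head starts at a different offset inside the block,
        # wrapping around with the modulo as in the diff.
        offset = head * per_block
        head_idx = []
        for block in range(num_blocks):
            for i in range(per_block):
                head_idx.append((i + offset) % block_size + block * block_size)
        indices.append(head_idx)
    return indices


# With 2 heads and sparsity_factor=2, the heads jointly cover every position.
print(block_stride_indices(seq_len=8, block_size=4, sparsity_factor=2, num_heads=2))
# → [[0, 1, 4, 5], [2, 3, 6, 7]]
```

Under this reading, each head keeps a contiguous slice of every block, and different heads keep different slices. This may be why the config warns when `sparsity_factor` exceeds the number of attention heads: some slices would then never be selected by any head.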
     def get_sparse_tokens_with_lsh(self, keys, values, mask):
 
         if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
 
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.adaptive = config.adaptive
 
         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
 
 
         return combined_attention_mask
 
+    def resize_inputs(self, inputs_embeds, attention_mask):
+        pad = 0
+
+        max_len = int(attention_mask.sum(dim=-1).max())
+        pad = attention_mask.size()[-1] - max_len
+        inputs_embeds = inputs_embeds[:, :max_len]
+        attention_mask = attention_mask[..., :max_len]
+        return pad, inputs_embeds, attention_mask
+
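The idea behind `resize_inputs` is to drop trailing padding columns that no sequence in the batch uses, and to remember how many were dropped so the output can be padded back to its original length later in `forward`. A minimal list-based sketch of that round trip follows; it assumes right-padded inputs (as the sum-based length in the diff implicitly does), and `restore_length` is a hypothetical helper standing in for the `torch.nn.functional.pad` call.

```python
def resize_inputs(rows, mask):
    """Trim columns beyond the longest real sequence in the batch.

    rows: list of equal-length lists, one per batch element
    mask: list of equal-length 0/1 lists (1 marks a real token)
    """
    max_len = max(sum(m) for m in mask)  # longest unpadded length in the batch
    pad = len(mask[0]) - max_len         # number of columns removed
    trimmed_rows = [r[:max_len] for r in rows]
    trimmed_mask = [m[:max_len] for m in mask]
    return pad, trimmed_rows, trimmed_mask


def restore_length(rows, pad, value=0):
    """Hypothetical helper: pad each row back to the original length."""
    return [r + [value] * pad for r in rows]


rows = [[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]]
mask = [[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]]
pad, short_rows, short_mask = resize_inputs(rows, mask)
print(pad, short_rows)  # → 2 [[1, 2, 3], [4, 5, 0]]
```

Only columns that are padding for every row are removed, so shorter rows in the batch still carry some padding after trimming.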
     def forward(
         self,
         input_ids=None,
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
+        # Resize to reduce computation
+        pad = 0
+        if self.adaptive:
+            if attention_mask is not None:
+                pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
+                input_shape = inputs_embeds.size()[:-1]
+            if encoder_attention_mask is not None:
+                _, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
 
                 if encoder_hidden_states is not None:
                     all_cross_attentions += (layer_outputs[2],)
 
+        # Resize to original shape
+        hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
+
         # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
         )
 
 
+class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
 
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
 
     def __init__(self, config):
 
+        LSGBartPretrainedModel.__init__(self, config)
         self.model = LSGBartModel(config)
         self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
         self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
 
 
+class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
 
+    def __init__(self, config: LSGBartConfig, **kwargs):
 
+        LSGBartPretrainedModel.__init__(self, config, **kwargs)
         self.model = LSGBartModel(config)
         self.classification_head = LSGBartClassificationHead(
             config.d_model,
 
         self.model._init_weights(self.classification_head.dense)
         self.model._init_weights(self.classification_head.out_proj)
 
 
+class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
 
+    def __init__(self, config: LSGBartConfig):
 
+        LSGBartPretrainedModel.__init__(self, config)
 
         config.num_labels = 2
         self.num_labels = config.num_labels
 
 
         self.model._init_weights(self.qa_outputs)
 
 
 class LSGBartDecoderWrapper(LSGBartPretrainedModel):
     """
 
     used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
     """
 
+    def __init__(self, config: LSGBartConfig):
         super().__init__(config)
         self.decoder = LSGBartDecoder(config)
 
 
         return self.decoder(*args, **kwargs)
 
 
+class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
 
+    def __init__(self, config: LSGBartConfig):
 
         config = copy.deepcopy(config)
         config.is_decoder = True
         config.is_encoder_decoder = False
+        LSGBartPretrainedModel.__init__(self, config)
         self.model = LSGBartDecoderWrapper(config)
 
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
 
 
 def str_to_class(classname):
     return getattr(sys.modules[__name__], classname)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:aaa7c2f67804eb3b01d0b555a6e9600fa80b188a9506560aa88ab19008a2e4b7
+oid sha256:b88fc0094b185dff97f0e9d44c155c561da2a130efd7c1860a1af192272ef286
 size 653914167
tokenizer.json CHANGED
@@ -6,16 +6,7 @@
     "strategy": "LongestFirst",
     "stride": 0
   },
-  "padding": {
-    "strategy": {
-      "Fixed": 320
-    },
-    "direction": "Right",
-    "pad_to_multiple_of": null,
-    "pad_id": 1,
-    "pad_type_id": 0,
-    "pad_token": "<pad>"
-  },
+  "padding": null,
   "added_tokens": [
     {
       "id": 0,
tokenizer_config.json CHANGED
@@ -1 +1 @@
-{"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": null, "name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp/arxiv/lsg_local_large_lr_16384_full_trained", "tokenizer_class": "BartTokenizer"}
+{"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": null, "name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/arxiv/lsg_local_16384_trained", "tokenizer_class": "BartTokenizer"}