ccdv committed
Commit
c9358be
1 Parent(s): 83d2fce
README.md CHANGED
@@ -18,22 +18,47 @@ should probably proofread and complete it, then remove this comment. -->
18
  **This model relies on a custom modeling file, so you need to add trust_remote_code=True**\
19
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
20
21
  # ccdv/lsg-bart-base-16384-pubmed
22
 
23
- This model is a fine-tuned version of [ccdv/lsg-bart-base-4096-pubmed](https://huggingface.co/ccdv/lsg-bart-base-4096-pubmed) on the scientific_papers pubmed dataset. \
 
24
  It achieves the following results on the test set:
25
 
26
- | Length | Global tokens | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
27
  |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
28
- | 16384 | 64 | - | 256 | 0 | 768 | 48.29 | 22.53 | 29.35 | 44.55 |
31
  ## Model description
32
  The model relies on Local-Sparse-Global attention to handle long sequences:
33
  ![attn](attn.png)
34
 
35
  The model has about 145 million parameters (6 encoder layers, 6 decoder layers). \
36
- The model is warm started from [ccdv/lsg-bart-base-4096-pubmed](https://huggingface.co/ccdv/lsg-bart-base-4096-pubmed), converted to handle long sequences (encoder only) and fine tuned. \
37
 
38
  ## Intended uses & limitations
39
 
@@ -49,12 +74,13 @@ More information needed
49
 
50
  The following hyperparameters were used during training:
51
  - learning_rate: 8e-05
52
- - train_batch_size: 1
53
  - seed: 42
54
- - gradient_accumulation_steps: 32
55
  - total_train_batch_size: 32
56
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
57
  - lr_scheduler_type: linear
 
58
  - num_epochs: 1.0
59
 
60
  ### Generate hyperparameters
@@ -62,14 +88,14 @@ The following hyperparameters were used during training:
62
  The following hyperparameters were used during generation:
63
  - dataset_name: scientific_papers
64
  - dataset_config_name: pubmed
65
- - eval_batch_size: 2
 
66
  - early_stopping: True
67
  - ignore_pad_token_for_loss: True
68
  - length_penalty: 2.0
69
  - max_length: 512
70
  - min_length: 128
71
  - num_beams: 5
72
- - num_samples: None
73
  - no_repeat_ngram_size: None
74
  - seed: 123
75
 
 
18
  **This model relies on a custom modeling file, so you need to add trust_remote_code=True**\
19
  **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
20
 
21
+ ```python
22
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
23
+
24
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-16384-pubmed", trust_remote_code=True)
25
+ model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-16384-pubmed", trust_remote_code=True)
26
+
27
+ text = "Replace with what you want."
28
+ pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
29
+ generated_text = pipe(
30
+ text,
31
+ truncation=True,
32
+ max_length=64,
33
+ no_repeat_ngram_size=7,
34
+ num_beams=2,
35
+ early_stopping=True
36
+ )
37
+ ```
38
+
39
  # ccdv/lsg-bart-base-16384-pubmed
40
 
41
+ This model is a fine-tuned version of [ccdv/lsg-bart-base-4096-pubmed](https://huggingface.co/ccdv/lsg-bart-base-4096-pubmed) on the [scientific_papers pubmed](https://huggingface.co/datasets/scientific_papers) dataset. \
42
+ The model is converted to handle sequences of up to 16384 tokens and fine-tuned accordingly for 1 epoch. \
43
  It achieves the following results on the test set:
44
 
45
+ | Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connections | R1 | R2 | RL | RLsum |
46
  |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
47
+ | 16384 | 64 | Full | 256 | 0 | 768 | 48.32 | 22.52 | 29.36 | 44.57 |
48
+ | 16384 | 1 | None | 256 | 0 | 768 | 48.03 | 22.42 | 29.28 | 44.32 |
49
+
50
+ Reference model:
51
 
52
+ | Length | Global tokens | Fine-tuning | Block Size | Sparsity | Connections | R1 | R2 | RL | RLsum |
53
+ |:------ |:------------- |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
54
+ | 4096 | 1 | - | 256 | 0 | 768 | 47.37 | 21.74 | 28.59 | 43.67 |
55
 
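The ROUGE figures reported in the tables above can be recomputed independently with the Hugging Face `evaluate` library; a minimal sketch (not part of this commit, and the prediction/reference strings are placeholders):

```python
# Minimal sketch: scoring generated summaries with ROUGE, as reported above.
# Requires: pip install evaluate rouge_score
import evaluate

rouge = evaluate.load("rouge")
predictions = ["generated summary of a pubmed article ..."]  # placeholder
references = ["reference abstract of the same article ..."]  # placeholder
scores = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print(scores)  # keys: rouge1, rouge2, rougeL, rougeLsum
```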
56
  ## Model description
57
  The model relies on Local-Sparse-Global attention to handle long sequences:
58
  ![attn](attn.png)
59
 
60
  The model has about 145 million parameters (6 encoder layers, 6 decoder layers). \
61
+ The model is warm-started from [ccdv/lsg-bart-base-4096-pubmed](https://huggingface.co/ccdv/lsg-bart-base-4096-pubmed), converted to handle long sequences (encoder only), and fine-tuned.
62
 
63
  ## Intended uses & limitations
64
 
 
74
 
75
  The following hyperparameters were used during training (see the sketch after this list):
76
  - learning_rate: 8e-05
77
+ - train_batch_size: 8
78
  - seed: 42
79
+ - gradient_accumulation_steps: 4
80
  - total_train_batch_size: 32
81
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
82
  - lr_scheduler_type: linear
83
+ - lr_scheduler_warmup_ratio: 0.1
84
  - num_epochs: 1.0
85
 
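A hedged sketch of how the training hyperparameters listed above map onto `transformers.Seq2SeqTrainingArguments`; the actual training script is not part of this repository and the `output_dir` is hypothetical:

```python
from transformers import Seq2SeqTrainingArguments

# Sketch only: mirrors the hyperparameters listed above (8 * 4 = 32 effective batch size).
training_args = Seq2SeqTrainingArguments(
    output_dir="lsg-bart-base-16384-pubmed",  # hypothetical path
    learning_rate=8e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_train_epochs=1.0,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    seed=42,
)
```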
86
  ### Generate hyperparameters
 
88
  The following hyperparameters were used during generation (see the sketch after this list):
89
  - dataset_name: scientific_papers
90
  - dataset_config_name: pubmed
91
+ - eval_batch_size: 4
92
+ - eval_samples: 6658
93
  - early_stopping: True
94
  - ignore_pad_token_for_loss: True
95
  - length_penalty: 2.0
96
  - max_length: 512
97
  - min_length: 128
98
  - num_beams: 5
 
99
  - no_repeat_ngram_size: None
100
  - seed: 123
101
 
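The generation hyperparameters above can also be passed directly to `model.generate`; a minimal sketch, assuming the model and tokenizer are loaded as in the pipeline example earlier in this card (the input article is a placeholder):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-16384-pubmed", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-16384-pubmed", trust_remote_code=True)

article = "full text of a pubmed article ..."  # placeholder
inputs = tokenizer(article, return_tensors="pt", truncation=True, max_length=16384)

# Beam search with the generation settings listed above.
summary_ids = model.generate(
    **inputs,
    num_beams=5,
    max_length=512,
    min_length=128,
    length_penalty=2.0,
    early_stopping=True,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```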
all_results.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "eval_gen_len": 337.5673,
3
+ "eval_loss": 1.505071759223938,
4
+ "eval_rouge1": 48.3164,
5
+ "eval_rouge2": 22.5152,
6
+ "eval_rougeL": 29.3582,
7
+ "eval_rougeLsum": 44.5727,
8
+ "eval_runtime": 19004.3788,
9
+ "eval_samples": 6658,
10
+ "eval_samples_per_second": 0.35,
11
+ "eval_steps_per_second": 0.088
12
+ }
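The file added above is plain JSON, so the reported metrics can be read back programmatically; a small sketch (assumes the file sits in the current directory):

```python
import json

# Load the evaluation metrics added in this commit.
with open("all_results.json") as f:
    results = json.load(f)

print(results["eval_rouge1"], results["eval_rouge2"], results["eval_rougeLsum"])
print(results["eval_samples"], "samples in", results["eval_runtime"], "seconds")
```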
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "ccdv/lsg-bart-base-16384-pubmed",
3
  "activation_dropout": 0.1,
4
  "activation_function": "gelu",
5
  "adaptive": true,
@@ -67,8 +67,8 @@
67
  "pool_with_global": true,
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
- "sparsity_factor": 2,
71
- "sparsity_type": "norm",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
@@ -90,7 +90,7 @@
90
  }
91
  },
92
  "torch_dtype": "float32",
93
- "transformers_version": "4.18.0",
94
  "use_cache": true,
95
  "vocab_size": 50265
96
  }
 
1
  {
2
+ "_name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/pubmed/lsg_local_16384_trained",
3
  "activation_dropout": 0.1,
4
  "activation_function": "gelu",
5
  "adaptive": true,
 
67
  "pool_with_global": true,
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
+ "sparsity_factor": 4,
71
+ "sparsity_type": "none",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
 
90
  }
91
  },
92
  "torch_dtype": "float32",
93
+ "transformers_version": "4.19.2",
94
  "use_cache": true,
95
  "vocab_size": 50265
96
  }
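The updated configuration values (`sparsity_type` changed from `"norm"` to `"none"`, `sparsity_factor` from 2 to 4, `transformers_version` bumped to 4.19.2) can be inspected directly from the hub; a minimal sketch using `AutoConfig`:

```python
from transformers import AutoConfig

# trust_remote_code is required because the config class lives in modeling_lsg_bart.py.
config = AutoConfig.from_pretrained("ccdv/lsg-bart-base-16384-pubmed", trust_remote_code=True)
print(config.sparsity_type, config.sparsity_factor, config.sparse_block_size)
```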
eval_results.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
- "eval_gen_len": 337.3282,
3
- "eval_loss": 1.5187041759490967,
4
- "eval_rouge1": 48.2871,
5
- "eval_rouge2": 22.5259,
6
- "eval_rougeL": 29.3512,
7
- "eval_rougeLsum": 44.5493,
8
- "eval_runtime": 32015.2043,
9
  "eval_samples": 6658,
10
- "eval_samples_per_second": 0.208,
11
- "eval_steps_per_second": 0.104
12
  }
 
1
  {
2
+ "eval_gen_len": 337.5673,
3
+ "eval_loss": 1.505071759223938,
4
+ "eval_rouge1": 48.3164,
5
+ "eval_rouge2": 22.5152,
6
+ "eval_rougeL": 29.3582,
7
+ "eval_rougeLsum": 44.5727,
8
+ "eval_runtime": 19004.3788,
9
  "eval_samples": 6658,
10
+ "eval_samples_per_second": 0.35,
11
+ "eval_steps_per_second": 0.088
12
  }
modeling_lsg_bart.py CHANGED
@@ -41,8 +41,6 @@ class LSGBartConfig(BartConfig):
41
  ):
42
  """Constructs LSGConfig."""
43
  super().__init__(**kwargs)
44
-
45
- assert sparsity_type in ["norm", "lsh", "pooling", "stride"], "Sparsity mode must be 'norm', 'lsh' or 'pooling'"
46
 
47
  self.adaptive = adaptive
48
  self.auto_map = AUTO_MAP
@@ -55,7 +53,33 @@ class LSGBartConfig(BartConfig):
55
  self.sparse_block_size = sparse_block_size
56
  self.sparsity_factor = sparsity_factor
57
  self.sparsity_type = sparsity_type
59
 
60
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
61
  """
@@ -208,8 +232,6 @@ class LSGAttentionProduct(nn.Module):
208
  # Shape of blocks
209
  self.local_shapes = (self.block_size*3, self.block_size)
210
  if self.sparse_block_size and self.sparsity_factor > 0:
211
- assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
212
- assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
213
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
214
 
215
  self.attention = BaseAttentionProduct(config)
@@ -390,24 +412,19 @@ class LSGBartEncoderAttention(BaseSelfAttention):
390
  "pooling": self.get_sparse_tokens_with_pooling,
391
  "lsh": self.get_sparse_tokens_with_lsh,
392
  "stride": self.get_sparse_tokens_with_stride,
 
393
  }
394
 
395
  self.sparsity_type = config.sparsity_type
396
- self.get_sparse_elements = sparse_functions[self.sparsity_type]
397
-
398
- if config.sparsity_type == "stride":
399
- if config.sparsity_factor > config.encoder_attention_heads:
400
- logger.warning(
401
- "Warning: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
402
- )
403
 
404
  if config.sparsity_type == "lsh":
405
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
406
-
407
  def get_sparse_tokens_with_norm(self, keys, values, mask):
408
 
409
  if self.sparsity_factor == 1:
410
- return keys, values, mask
411
 
412
  with torch.no_grad():
413
 
@@ -435,7 +452,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
435
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
436
 
437
  if self.sparsity_factor == 1:
438
- return keys, values, mask
439
 
440
  keys = self.chunk(keys, self.sparsity_factor)
441
  values = self.chunk(values, self.sparsity_factor)
@@ -457,7 +474,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
457
  def get_sparse_tokens_with_stride(self, keys, values, mask):
458
 
459
  if self.sparsity_factor == 1:
460
- return keys, values, mask
461
 
462
  n, h, t, d = keys.size()
463
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
@@ -470,10 +487,30 @@ class LSGBartEncoderAttention(BaseSelfAttention):
470
 
471
  return keys, values, mask
472
 
473
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
474
 
475
  if self.sparsity_factor == 1:
476
- return keys, values, mask
477
 
478
  block_size = min(self.block_size, self.sparse_block_size)
479
  keys = self.chunk(keys, block_size)
@@ -490,9 +527,9 @@ class LSGBartEncoderAttention(BaseSelfAttention):
490
  extra_factor = 1
491
 
492
  for _ in range(self.lsh_num_pre_rounds):
493
- keys, values, mask = self.lsg_round(keys, values, mask, t*extra_factor)
494
 
495
- keys, values, mask = self.lsg_round(keys, values, mask, t//self.sparsity_factor)
496
  keys /= mask + 1e-8
497
  values /= mask + 1e-8
498
 
@@ -500,7 +537,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
500
 
501
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
502
 
503
- def lsg_round(self, keys, values, mask, output_size):
504
 
505
  with torch.no_grad():
506
 
@@ -1304,6 +1341,7 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1304
  self.padding_idx = config.pad_token_id
1305
  self.max_target_positions = config.max_position_embeddings
1306
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
1307
 
1308
  if embed_tokens is not None:
1309
  self.embed_tokens = embed_tokens
@@ -1346,6 +1384,15 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1346
 
1347
  return combined_attention_mask
1348
 
1349
  def forward(
1350
  self,
1351
  input_ids=None,
@@ -1386,12 +1433,14 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1386
  if inputs_embeds is None:
1387
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1388
 
1389
- # Cut
1390
- if attention_mask is not None:
1391
- max_len = int(attention_mask.sum(dim=-1).max())
1392
- inputs_embeds = inputs_embeds[:, :max_len]
1393
- attention_mask = attention_mask[..., :max_len]
1394
- input_shape = inputs_embeds.size()[:-1]
 
 
1395
 
1396
  attention_mask = self._prepare_decoder_attention_mask(
1397
  attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1485,6 +1534,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1485
  if encoder_hidden_states is not None:
1486
  all_cross_attentions += (layer_outputs[2],)
1487
 
 
 
 
1488
  # add hidden states from the last decoder layer
1489
  if output_hidden_states:
1490
  all_hidden_states += (hidden_states,)
@@ -1621,14 +1673,14 @@ class LSGBartModel(LSGBartPretrainedModel):
1621
  )
1622
 
1623
 
1624
- class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1625
 
1626
  base_model_prefix = "model"
1627
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1628
 
1629
  def __init__(self, config):
1630
 
1631
- super().__init__(config)
1632
  self.model = LSGBartModel(config)
1633
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1634
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
@@ -1636,157 +1688,12 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1636
  # Initialize weights and apply final processing
1637
  self.post_init()
1638
 
1639
- def get_encoder(self):
1640
- return self.model.get_encoder()
1641
-
1642
- def get_decoder(self):
1643
- return self.model.get_decoder()
1644
 
1645
- def resize_token_embeddings(self, new_num_tokens):
1646
- new_embeddings = super().resize_token_embeddings(new_num_tokens)
1647
- self._resize_final_logits_bias(new_num_tokens)
1648
- return new_embeddings
1649
 
1650
- def _resize_final_logits_bias(self, new_num_tokens):
1651
- old_num_tokens = self.final_logits_bias.shape[-1]
1652
- if new_num_tokens <= old_num_tokens:
1653
- new_bias = self.final_logits_bias[:, :new_num_tokens]
1654
- else:
1655
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1656
- new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1657
- self.register_buffer("final_logits_bias", new_bias)
1658
-
1659
- def get_output_embeddings(self):
1660
- return self.lm_head
1661
-
1662
- def set_output_embeddings(self, new_embeddings):
1663
- self.lm_head = new_embeddings
1664
-
1665
- def forward(
1666
- self,
1667
- input_ids=None,
1668
- attention_mask=None,
1669
- decoder_input_ids=None,
1670
- decoder_attention_mask=None,
1671
- head_mask=None,
1672
- decoder_head_mask=None,
1673
- cross_attn_head_mask=None,
1674
- encoder_outputs=None,
1675
- past_key_values=None,
1676
- inputs_embeds=None,
1677
- decoder_inputs_embeds=None,
1678
- labels=None,
1679
- use_cache=None,
1680
- output_attentions=None,
1681
- output_hidden_states=None,
1682
- return_dict=None,
1683
- ):
1684
 
1685
- r"""
1686
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1687
- Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
1688
- config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1689
- (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1690
- Returns:
1691
- """
1692
-
1693
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1694
-
1695
- if labels is not None:
1696
- if decoder_input_ids is None and decoder_inputs_embeds is None:
1697
- decoder_input_ids = shift_tokens_right(
1698
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
1699
- )
1700
-
1701
- outputs = self.model(
1702
- input_ids,
1703
- attention_mask=attention_mask,
1704
- decoder_input_ids=decoder_input_ids,
1705
- encoder_outputs=encoder_outputs,
1706
- decoder_attention_mask=decoder_attention_mask,
1707
- head_mask=head_mask,
1708
- decoder_head_mask=decoder_head_mask,
1709
- cross_attn_head_mask=cross_attn_head_mask,
1710
- past_key_values=past_key_values,
1711
- inputs_embeds=inputs_embeds,
1712
- decoder_inputs_embeds=decoder_inputs_embeds,
1713
- use_cache=use_cache,
1714
- output_attentions=output_attentions,
1715
- output_hidden_states=output_hidden_states,
1716
- return_dict=return_dict,
1717
- )
1718
-
1719
-
1720
- lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1721
-
1722
- masked_lm_loss = None
1723
- if labels is not None:
1724
- loss_fct = CrossEntropyLoss()
1725
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1726
-
1727
- if not return_dict:
1728
- output = (lm_logits,) + outputs[1:]
1729
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1730
-
1731
- return Seq2SeqLMOutput(
1732
- loss=masked_lm_loss,
1733
- logits=lm_logits,
1734
- past_key_values=outputs.past_key_values,
1735
- decoder_hidden_states=outputs.decoder_hidden_states,
1736
- decoder_attentions=outputs.decoder_attentions,
1737
- cross_attentions=outputs.cross_attentions,
1738
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1739
- encoder_hidden_states=outputs.encoder_hidden_states,
1740
- encoder_attentions=outputs.encoder_attentions,
1741
- )
1742
-
1743
- def prepare_inputs_for_generation(
1744
- self,
1745
- decoder_input_ids,
1746
- past=None,
1747
- attention_mask=None,
1748
- head_mask=None,
1749
- decoder_head_mask=None,
1750
- cross_attn_head_mask=None,
1751
- use_cache=None,
1752
- encoder_outputs=None,
1753
- **kwargs
1754
- ):
1755
- # cut decoder_input_ids if past is used
1756
- if past is not None:
1757
- decoder_input_ids = decoder_input_ids[:, -1:]
1758
-
1759
- return {
1760
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
1761
- "encoder_outputs": encoder_outputs,
1762
- "past_key_values": past,
1763
- "decoder_input_ids": decoder_input_ids,
1764
- "attention_mask": attention_mask,
1765
- "head_mask": head_mask,
1766
- "decoder_head_mask": decoder_head_mask,
1767
- "cross_attn_head_mask": cross_attn_head_mask,
1768
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1769
- }
1770
-
1771
- def prepare_decoder_input_ids_from_labels(self, labels):
1772
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1773
-
1774
- @staticmethod
1775
- def _reorder_cache(past, beam_idx):
1776
- reordered_past = ()
1777
- for layer_past in past:
1778
- # cached cross_attention states don't have to be reordered -> they are always the same
1779
- reordered_past += (
1780
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1781
- )
1782
- return reordered_past
1783
-
1784
-
1785
- class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1786
-
1787
- def __init__(self, config, **kwargs):
1788
-
1789
- super().__init__(config, **kwargs)
1790
  self.model = LSGBartModel(config)
1791
  self.classification_head = LSGBartClassificationHead(
1792
  config.d_model,
@@ -1797,115 +1704,12 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1797
  self.model._init_weights(self.classification_head.dense)
1798
  self.model._init_weights(self.classification_head.out_proj)
1799
 
1800
- def forward(
1801
- self,
1802
- input_ids=None,
1803
- attention_mask=None,
1804
- decoder_input_ids=None,
1805
- decoder_attention_mask=None,
1806
- head_mask=None,
1807
- decoder_head_mask=None,
1808
- cross_attn_head_mask=None,
1809
- encoder_outputs=None,
1810
- inputs_embeds=None,
1811
- decoder_inputs_embeds=None,
1812
- labels=None,
1813
- use_cache=None,
1814
- output_attentions=None,
1815
- output_hidden_states=None,
1816
- return_dict=None,
1817
- ):
1818
-
1819
- r"""
1820
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1821
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1822
- config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1823
- """
1824
 
1825
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1826
- if labels is not None:
1827
- use_cache = False
1828
 
1829
- if input_ids is None and inputs_embeds is not None:
1830
- raise NotImplementedError(
1831
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1832
- )
1833
 
1834
- outputs = self.model(
1835
- input_ids,
1836
- attention_mask=attention_mask,
1837
- decoder_input_ids=decoder_input_ids,
1838
- decoder_attention_mask=decoder_attention_mask,
1839
- head_mask=head_mask,
1840
- decoder_head_mask=decoder_head_mask,
1841
- cross_attn_head_mask=cross_attn_head_mask,
1842
- encoder_outputs=encoder_outputs,
1843
- inputs_embeds=inputs_embeds,
1844
- decoder_inputs_embeds=decoder_inputs_embeds,
1845
- use_cache=use_cache,
1846
- output_attentions=output_attentions,
1847
- output_hidden_states=output_hidden_states,
1848
- return_dict=return_dict,
1849
- )
1850
- hidden_states = outputs[0] # last hidden state
1851
-
1852
- eos_mask = input_ids.eq(self.config.eos_token_id)
1853
-
1854
- t, t_ = eos_mask.size()[-1], hidden_states.size()[-2]
1855
- if t > t_:
1856
- eos_mask = eos_mask[:, :t_]
1857
-
1858
- if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1859
- raise ValueError("All examples must have the same number of <eos> tokens.")
1860
- sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1861
- :, -1, :
1862
- ]
1863
- logits = self.classification_head(sentence_representation)
1864
-
1865
- loss = None
1866
- if labels is not None:
1867
- if self.config.problem_type is None:
1868
- if self.config.num_labels == 1:
1869
- self.config.problem_type = "regression"
1870
- elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1871
- self.config.problem_type = "single_label_classification"
1872
- else:
1873
- self.config.problem_type = "multi_label_classification"
1874
-
1875
- if self.config.problem_type == "regression":
1876
- loss_fct = MSELoss()
1877
- if self.config.num_labels == 1:
1878
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1879
- else:
1880
- loss = loss_fct(logits, labels)
1881
- elif self.config.problem_type == "single_label_classification":
1882
- loss_fct = CrossEntropyLoss()
1883
- loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1884
- elif self.config.problem_type == "multi_label_classification":
1885
- loss_fct = BCEWithLogitsLoss()
1886
- loss = loss_fct(logits, labels)
1887
- if not return_dict:
1888
- output = (logits,) + outputs[1:]
1889
- return ((loss,) + output) if loss is not None else output
1890
-
1891
- return Seq2SeqSequenceClassifierOutput(
1892
- loss=loss,
1893
- logits=logits,
1894
- past_key_values=outputs.past_key_values,
1895
- decoder_hidden_states=outputs.decoder_hidden_states,
1896
- decoder_attentions=outputs.decoder_attentions,
1897
- cross_attentions=outputs.cross_attentions,
1898
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1899
- encoder_hidden_states=outputs.encoder_hidden_states,
1900
- encoder_attentions=outputs.encoder_attentions,
1901
- )
1902
-
1903
-
1904
- class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1905
-
1906
- def __init__(self, config):
1907
-
1908
- super().__init__(config)
1909
 
1910
  config.num_labels = 2
1911
  self.num_labels = config.num_labels
@@ -1915,102 +1719,6 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1915
 
1916
  self.model._init_weights(self.qa_outputs)
1917
 
1918
- def forward(
1919
- self,
1920
- input_ids=None,
1921
- attention_mask=None,
1922
- decoder_input_ids=None,
1923
- decoder_attention_mask=None,
1924
- head_mask=None,
1925
- decoder_head_mask=None,
1926
- cross_attn_head_mask=None,
1927
- encoder_outputs=None,
1928
- start_positions=None,
1929
- end_positions=None,
1930
- inputs_embeds=None,
1931
- decoder_inputs_embeds=None,
1932
- use_cache=None,
1933
- output_attentions=None,
1934
- output_hidden_states=None,
1935
- return_dict=None,
1936
- ):
1937
-
1938
- r"""
1939
- start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1940
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1941
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1942
- are not taken into account for computing the loss.
1943
- end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1944
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1945
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1946
- are not taken into account for computing the loss.
1947
- """
1948
-
1949
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1950
- if start_positions is not None and end_positions is not None:
1951
- use_cache = False
1952
-
1953
- outputs = self.model(
1954
- input_ids,
1955
- attention_mask=attention_mask,
1956
- decoder_input_ids=decoder_input_ids,
1957
- decoder_attention_mask=decoder_attention_mask,
1958
- head_mask=head_mask,
1959
- decoder_head_mask=decoder_head_mask,
1960
- cross_attn_head_mask=cross_attn_head_mask,
1961
- encoder_outputs=encoder_outputs,
1962
- inputs_embeds=inputs_embeds,
1963
- decoder_inputs_embeds=decoder_inputs_embeds,
1964
- use_cache=use_cache,
1965
- output_attentions=output_attentions,
1966
- output_hidden_states=output_hidden_states,
1967
- return_dict=return_dict,
1968
- )
1969
-
1970
- sequence_output = outputs[0]
1971
-
1972
- logits = self.qa_outputs(sequence_output)
1973
- start_logits, end_logits = logits.split(1, dim=-1)
1974
- start_logits = start_logits.squeeze(-1).contiguous()
1975
- end_logits = end_logits.squeeze(-1).contiguous()
1976
-
1977
- total_loss = None
1978
- if start_positions is not None and end_positions is not None:
1979
- # If we are on multi-GPU, split add a dimension
1980
- if len(start_positions.size()) > 1:
1981
- start_positions = start_positions.squeeze(-1)
1982
- if len(end_positions.size()) > 1:
1983
- end_positions = end_positions.squeeze(-1)
1984
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1985
- ignored_index = start_logits.size(1)
1986
- start_positions = start_positions.clamp(0, ignored_index)
1987
- end_positions = end_positions.clamp(0, ignored_index)
1988
-
1989
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1990
- start_loss = loss_fct(start_logits, start_positions)
1991
- end_loss = loss_fct(end_logits, end_positions)
1992
- total_loss = (start_loss + end_loss) / 2
1993
-
1994
- if not return_dict:
1995
- output = (
1996
- start_logits,
1997
- end_logits,
1998
- ) + outputs[1:]
1999
- return ((total_loss,) + output) if total_loss is not None else output
2000
-
2001
- return Seq2SeqQuestionAnsweringModelOutput(
2002
- loss=total_loss,
2003
- start_logits=start_logits,
2004
- end_logits=end_logits,
2005
- past_key_values=outputs.past_key_values,
2006
- decoder_hidden_states=outputs.decoder_hidden_states,
2007
- decoder_attentions=outputs.decoder_attentions,
2008
- cross_attentions=outputs.cross_attentions,
2009
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
2010
- encoder_hidden_states=outputs.encoder_hidden_states,
2011
- encoder_attentions=outputs.encoder_attentions,
2012
- )
2013
-
2014
 
2015
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2016
  """
@@ -2018,7 +1726,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2018
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
2019
  """
2020
 
2021
- def __init__(self, config):
2022
  super().__init__(config)
2023
  self.decoder = LSGBartDecoder(config)
2024
 
@@ -2026,14 +1734,14 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2026
  return self.decoder(*args, **kwargs)
2027
 
2028
 
2029
- class LSGBartForCausalLM(LSGBartPretrainedModel):
2030
 
2031
- def __init__(self, config):
2032
 
2033
- super().__init__(config)
2034
  config = copy.deepcopy(config)
2035
  config.is_decoder = True
2036
  config.is_encoder_decoder = False
 
2037
  self.model = LSGBartDecoderWrapper(config)
2038
 
2039
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -2041,105 +1749,6 @@ class LSGBartForCausalLM(LSGBartPretrainedModel):
2041
  # Initialize weights and apply final processing
2042
  self.post_init()
2043
 
2044
- def get_input_embeddings(self):
2045
- return self.model.decoder.embed_tokens
2046
-
2047
- def set_input_embeddings(self, value):
2048
- self.model.decoder.embed_tokens = value
2049
-
2050
- def get_output_embeddings(self):
2051
- return self.lm_head
2052
-
2053
- def set_output_embeddings(self, new_embeddings):
2054
- self.lm_head = new_embeddings
2055
-
2056
- def set_decoder(self, decoder):
2057
- self.model.decoder = decoder
2058
-
2059
- def get_decoder(self):
2060
- return self.model.decoder
2061
-
2062
- def forward(
2063
- self,
2064
- input_ids=None,
2065
- attention_mask=None,
2066
- encoder_hidden_states=None,
2067
- encoder_attention_mask=None,
2068
- head_mask=None,
2069
- cross_attn_head_mask=None,
2070
- past_key_values=None,
2071
- inputs_embeds=None,
2072
- labels=None,
2073
- use_cache=None,
2074
- output_attentions=None,
2075
- output_hidden_states=None,
2076
- return_dict=None,
2077
- ):
2078
-
2079
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2080
- output_hidden_states = (
2081
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2082
- )
2083
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2084
-
2085
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2086
- outputs = self.model.decoder(
2087
- input_ids=input_ids,
2088
- attention_mask=attention_mask,
2089
- encoder_hidden_states=encoder_hidden_states,
2090
- encoder_attention_mask=encoder_attention_mask,
2091
- head_mask=head_mask,
2092
- cross_attn_head_mask=cross_attn_head_mask,
2093
- past_key_values=past_key_values,
2094
- inputs_embeds=inputs_embeds,
2095
- use_cache=use_cache,
2096
- output_attentions=output_attentions,
2097
- output_hidden_states=output_hidden_states,
2098
- return_dict=return_dict,
2099
- )
2100
-
2101
- logits = self.lm_head(outputs[0])
2102
-
2103
- loss = None
2104
- if labels is not None:
2105
- loss_fct = CrossEntropyLoss()
2106
- loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2107
-
2108
- if not return_dict:
2109
- output = (logits,) + outputs[1:]
2110
- return (loss,) + output if loss is not None else output
2111
-
2112
- return CausalLMOutputWithCrossAttentions(
2113
- loss=loss,
2114
- logits=logits,
2115
- past_key_values=outputs.past_key_values,
2116
- hidden_states=outputs.hidden_states,
2117
- attentions=outputs.attentions,
2118
- cross_attentions=outputs.cross_attentions,
2119
- )
2120
-
2121
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
2122
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
2123
- if attention_mask is None:
2124
- attention_mask = input_ids.new_ones(input_ids.shape)
2125
-
2126
- if past:
2127
- input_ids = input_ids[:, -1:]
2128
- # first step, decoder_cached_states are empty
2129
- return {
2130
- "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
2131
- "attention_mask": attention_mask,
2132
- "past_key_values": past,
2133
- "use_cache": use_cache,
2134
- }
2135
-
2136
- @staticmethod
2137
- def _reorder_cache(past, beam_idx):
2138
- reordered_past = ()
2139
- for layer_past in past:
2140
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
2141
- return reordered_past
2142
-
2143
 
2144
  def str_to_class(classname):
2145
  return getattr(sys.modules[__name__], classname)
 
41
  ):
42
  """Constructs LSGConfig."""
43
  super().__init__(**kwargs)
 
 
44
 
45
  self.adaptive = adaptive
46
  self.auto_map = AUTO_MAP
 
53
  self.sparse_block_size = sparse_block_size
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
+
57
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
58
+ logger.warning(
59
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
60
+ self.sparsity_type = None
61
+
62
+ if self.sparsity_type in ["stride", "block_stride"]:
63
+ if self.sparsity_factor > self.encoder_attention_heads:
64
+ logger.warning(
65
+ "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
66
+ )
67
 
68
+ if self.num_global_tokens < 1:
69
+ logger.warning(
70
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
71
+ )
72
+ self.num_global_tokens = 1
73
+ elif self.num_global_tokens > 512:
74
+ logger.warning(
75
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
76
+ )
77
+ self.num_global_tokens = 512
78
+
79
+ if self.sparsity_factor > 0:
80
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
81
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
82
+
83
 
84
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
85
  """
 
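The checks added above replace the previous hard assertion on `sparsity_type`: unknown sparsity types now fall back to `None` with a warning, `num_global_tokens` is clamped to the `[1, 512]` range, and the divisibility assert only runs when `sparsity_factor > 0`. An illustrative standalone restatement of that logic (not the library code):

```python
import logging

logger = logging.getLogger(__name__)
VALID_SPARSITY_TYPES = [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]

def validate_lsg_config(sparsity_type, sparsity_factor, block_size, num_global_tokens):
    # Unknown sparsity types degrade to None (sparse attention skipped) instead of raising.
    if sparsity_type not in VALID_SPARSITY_TYPES:
        logger.warning("unknown sparsity_type %r, falling back to None", sparsity_type)
        sparsity_type = None
    # Global token count is clamped to the supported [1, 512] range.
    num_global_tokens = max(1, min(num_global_tokens, 512))
    if sparsity_factor > 0:
        assert block_size % sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
    return sparsity_type, num_global_tokens
```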
232
  # Shape of blocks
233
  self.local_shapes = (self.block_size*3, self.block_size)
234
  if self.sparse_block_size and self.sparsity_factor > 0:
 
 
235
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
236
 
237
  self.attention = BaseAttentionProduct(config)
 
412
  "pooling": self.get_sparse_tokens_with_pooling,
413
  "lsh": self.get_sparse_tokens_with_lsh,
414
  "stride": self.get_sparse_tokens_with_stride,
415
+ "block_stride": self.get_sparse_tokens_with_block_stride,
416
  }
417
 
418
  self.sparsity_type = config.sparsity_type
419
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
 
 
 
 
 
 
420
 
421
  if config.sparsity_type == "lsh":
422
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
423
+
424
  def get_sparse_tokens_with_norm(self, keys, values, mask):
425
 
426
  if self.sparsity_factor == 1:
427
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
428
 
429
  with torch.no_grad():
430
 
 
452
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
453
 
454
  if self.sparsity_factor == 1:
455
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
456
 
457
  keys = self.chunk(keys, self.sparsity_factor)
458
  values = self.chunk(values, self.sparsity_factor)
 
474
  def get_sparse_tokens_with_stride(self, keys, values, mask):
475
 
476
  if self.sparsity_factor == 1:
477
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
478
 
479
  n, h, t, d = keys.size()
480
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
 
487
 
488
  return keys, values, mask
489
 
490
+ def get_sparse_tokens_with_block_stride(self, keys, values, mask):
491
+
492
+ if self.sparsity_factor == 1:
493
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
494
+
495
+ n, h, t, d = keys.size()
496
+
497
+ t, b = self.block_size, t // self.block_size
498
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
499
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
500
+ sparse_idx = (sparse_idx % t)
501
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
502
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
503
+
504
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
505
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
506
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
507
+
508
+ return keys, values, mask
509
+
510
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
511
 
512
  if self.sparsity_factor == 1:
513
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
514
 
515
  block_size = min(self.block_size, self.sparse_block_size)
516
  keys = self.chunk(keys, block_size)
 
527
  extra_factor = 1
528
 
529
  for _ in range(self.lsh_num_pre_rounds):
530
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
531
 
532
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
533
  keys /= mask + 1e-8
534
  values /= mask + 1e-8
535
 
 
537
 
538
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
539
 
540
+ def lsh_round(self, keys, values, mask, output_size):
541
 
542
  with torch.no_grad():
543
 
 
1341
  self.padding_idx = config.pad_token_id
1342
  self.max_target_positions = config.max_position_embeddings
1343
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1344
+ self.adaptive = config.adaptive
1345
 
1346
  if embed_tokens is not None:
1347
  self.embed_tokens = embed_tokens
 
1384
 
1385
  return combined_attention_mask
1386
 
1387
+ def resize_inputs(self, inputs_embeds, attention_mask):
1388
+ pad = 0
1389
+
1390
+ max_len = int(attention_mask.sum(dim=-1).max())
1391
+ pad = attention_mask.size()[-1] - max_len
1392
+ inputs_embeds = inputs_embeds[:, :max_len]
1393
+ attention_mask = attention_mask[..., :max_len]
1394
+ return pad, inputs_embeds, attention_mask
1395
+
1396
  def forward(
1397
  self,
1398
  input_ids=None,
 
1433
  if inputs_embeds is None:
1434
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1435
 
1436
+ # Resize to reduce computation
1437
+ pad = 0
1438
+ if self.adaptive:
1439
+ if attention_mask is not None:
1440
+ pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
1441
+ input_shape = inputs_embeds.size()[:-1]
1442
+ if encoder_attention_mask is not None:
1443
+ _, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
1444
 
1445
  attention_mask = self._prepare_decoder_attention_mask(
1446
  attention_mask, input_shape, inputs_embeds, past_key_values_length
 
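The `adaptive` path added above trims trailing padding so the decoder only runs over the longest real sequence in the batch, then pads the hidden states back to the original length afterwards. A toy sketch of the trimming step in isolation (illustrative only):

```python
import torch

def resize_inputs(inputs_embeds, attention_mask):
    # Keep only positions up to the longest unpadded sequence in the batch.
    max_len = int(attention_mask.sum(dim=-1).max())
    pad = attention_mask.size(-1) - max_len
    return pad, inputs_embeds[:, :max_len], attention_mask[..., :max_len]

embeds = torch.randn(2, 8, 4)
mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 0, 0, 0]])
pad, embeds, mask = resize_inputs(embeds, mask)
print(pad, embeds.shape, mask.shape)  # 3 torch.Size([2, 5, 4]) torch.Size([2, 5])
```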
1534
  if encoder_hidden_states is not None:
1535
  all_cross_attentions += (layer_outputs[2],)
1536
 
1537
+ # Resize to original shape
1538
+ hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
1539
+
1540
  # add hidden states from the last decoder layer
1541
  if output_hidden_states:
1542
  all_hidden_states += (hidden_states,)
 
1673
  )
1674
 
1675
 
1676
+ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
1677
 
1678
  base_model_prefix = "model"
1679
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1680
 
1681
  def __init__(self, config):
1682
 
1683
+ LSGBartPretrainedModel.__init__(self, config)
1684
  self.model = LSGBartModel(config)
1685
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1686
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
1688
  # Initialize weights and apply final processing
1689
  self.post_init()
1690
 
 
 
 
 
 
1691
 
1692
+ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
 
 
 
1693
 
1694
+ def __init__(self, config: LSGBartConfig, **kwargs):
1695
 
1696
+ LSGBartPretrainedModel.__init__(self, config, **kwargs)
1697
  self.model = LSGBartModel(config)
1698
  self.classification_head = LSGBartClassificationHead(
1699
  config.d_model,
 
1704
  self.model._init_weights(self.classification_head.dense)
1705
  self.model._init_weights(self.classification_head.out_proj)
1706
 
1707
 
1708
+ class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
 
 
1709
 
1710
+ def __init__(self, config: LSGBartConfig):
 
 
 
1711
 
1712
+ LSGBartPretrainedModel.__init__(self, config)
 
1713
 
1714
  config.num_labels = 2
1715
  self.num_labels = config.num_labels
 
1719
 
1720
  self.model._init_weights(self.qa_outputs)
1721
 
1722
 
1723
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
1724
  """
 
1726
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
1727
  """
1728
 
1729
+ def __init__(self, config: LSGBartConfig):
1730
  super().__init__(config)
1731
  self.decoder = LSGBartDecoder(config)
1732
 
 
1734
  return self.decoder(*args, **kwargs)
1735
 
1736
 
1737
+ class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
1738
 
1739
+ def __init__(self, config: LSGBartConfig):
1740
 
 
1741
  config = copy.deepcopy(config)
1742
  config.is_decoder = True
1743
  config.is_encoder_decoder = False
1744
+ LSGBartPretrainedModel.__init__(self, config)
1745
  self.model = LSGBartDecoderWrapper(config)
1746
 
1747
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
1749
  # Initialize weights and apply final processing
1750
  self.post_init()
1751
1752
 
1753
  def str_to_class(classname):
1754
  return getattr(sys.modules[__name__], classname)
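Most of the deletions in this file come from the task heads now inheriting their `forward`, generation, and cache-handling methods from the stock BART classes (`BartForConditionalGeneration`, `BartForSequenceClassification`, `BartForQuestionAnswering`, `BartForCausalLM`) instead of re-declaring them. A toy illustration of the multiple-inheritance pattern, not the library code:

```python
# Methods not overridden in the subclass are resolved through the MRO,
# so the behaviour of the first base class is reused without duplication.
class BaseTaskHead:
    def forward(self):
        return "base forward"

class PretrainedMixin:
    def post_init(self):
        return "weights initialized"

class LongSequenceTaskHead(BaseTaskHead, PretrainedMixin):
    pass

head = LongSequenceTaskHead()
print(LongSequenceTaskHead.__mro__)  # (LongSequenceTaskHead, BaseTaskHead, PretrainedMixin, object)
print(head.forward(), "/", head.post_init())
```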
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:50d0ecc7c34b142bf2c8ff8485397795187896aaaf99ce996b716877f1886a68
3
  size 653914167
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b2d1ad40fbb543df25bbb25c465dfd0f87a0399f6b2d6bbed13b569d9345e0
3
  size 653914167
tokenizer.json CHANGED
@@ -6,16 +6,7 @@
6
  "strategy": "LongestFirst",
7
  "stride": 0
8
  },
9
- "padding": {
10
- "strategy": {
11
- "Fixed": 512
12
- },
13
- "direction": "Right",
14
- "pad_to_multiple_of": null,
15
- "pad_id": 1,
16
- "pad_type_id": 0,
17
- "pad_token": "<pad>"
18
- },
19
  "added_tokens": [
20
  {
21
  "id": 0,
 
6
  "strategy": "LongestFirst",
7
  "stride": 0
8
  },
9
+ "padding": null,
10
  "added_tokens": [
11
  {
12
  "id": 0,
tokenizer_config.json CHANGED
@@ -1 +1 @@
1
- {"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": null, "name_or_path": "tmp/pubmed/lsg_local_large_lr_16384_full_trained", "tokenizer_class": "BartTokenizer"}
 
1
+ {"errors": "replace", "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": "<mask>", "add_prefix_space": false, "trim_offsets": true, "model_max_length": 16384, "special_tokens_map_file": null, "name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/pubmed/lsg_local_16384_trained", "tokenizer_class": "BartTokenizer"}
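With the fixed 512-token padding strategy removed from tokenizer.json (`"padding": null`) and `model_max_length` kept at 16384 in tokenizer_config.json, padding is now decided at call time; a minimal usage sketch:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-16384-pubmed", trust_remote_code=True)

batch = tokenizer(
    ["a short document ...", "a much longer second document ..."],  # placeholders
    padding="longest",   # pad to the longest sequence in the batch, not to a fixed 512
    truncation=True,
    max_length=16384,
    return_tensors="pt",
)
print(batch["input_ids"].shape)
```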