ccdv commited on
Commit
d910415
1 Parent(s): 85efa08
Files changed (4) hide show
  1. README.md +21 -13
  2. config.json +3 -3
  3. modeling_lsg_bart.py +85 -480
  4. pytorch_model.bin +1 -1
README.md CHANGED
@@ -23,16 +23,23 @@ should probably proofread and complete it, then remove this comment. -->
23
  This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the scientific_papers arxiv dataset. \
24
  It achieves the following results on the test set:
25
 
26
- | Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
27
- |:------ |:----------- |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
28
- | 4096 | - | 256 | 0 | 768 | 46.29 | 18.71 | 26.77 | 41.85 |
29
- | 4096 | - | 128 | 0 | 384 | 45.87 | 18.44 | 26.66 | 41.42 |
30
- | 4096 | Stride | 128 | 4 | 644 | 46.07 | 18.51 | 26.61 | 41.58 |
31
- | 4096 | Pooling | 128 | 4 | 644 | 46.02 | 18.52 | 26.73 | 41.55 |
32
- | 4096 | LSH | 128 | 4 | 644 | 45.78 | 18.48 | 26.70 | 41.40 |
33
- | 4096 | Norm | 128 | 4 | 644 | 45.76 | 18.26 | 26.36 | 41.26 |
34
-
35
-
 
 
 
 
 
 
 
36
 
37
  ## Model description
38
  The model relies on Local-Sparse-Global attention to handle long sequences:
@@ -61,7 +68,8 @@ The following hyperparameters were used during training:
61
  - total_train_batch_size: 32
62
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
63
  - lr_scheduler_type: linear
64
- - num_epochs: 5.0
 
65
 
66
  ### Generate hyperparameters
67
 
@@ -69,13 +77,13 @@ The following hyperparameters were used during generation:
69
  - dataset_name: scientific_papers
70
  - dataset_config_name: arxiv
71
  - eval_batch_size: 8
 
72
  - early_stopping: True
73
  - ignore_pad_token_for_loss: True
74
  - length_penalty: 2.0
75
  - max_length: 320
76
- - min_length: 64
77
  - num_beams: 5
78
- - num_samples: None
79
  - no_repeat_ngram_size: None
80
  - seed: 123
81
 
 
23
  This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the scientific_papers arxiv dataset. \
24
  It achieves the following results on the test set:
25
 
26
+ | Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
27
+ |:------ |:------------ |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
28
+ | 4096 | Local | 256 | 0 | 768 | 46.65 | 18.91 | 26.90 | 42.18 |
29
+ | 4096 | Local | 128 | 0 | 384 | 46.18 | 18.57 | 26.71 | 41.69 |
30
+ | 4096 | Pooling | 128 | 4 | 644 | 46.27 | 18.68 | 26.87 | 41.82 |
31
+ | 4096 | Stride | 128 | 4 | 644 | 46.34 | 18.64 | 26.69 | 41.87 |
32
+ | 4096 | Norm | 128 | 4 | 644 | 45.96 | 18.46 | 26.52 | 41.51 |
33
+ | 4096 | LSH | 128 | 4 | 644 | 46.19 | 18.72 | 26.89 | 41.76 |
34
+
35
+ With blocks of size 32 (lower ressources):
36
+ | Length | Sparse Type | Block Size | Sparsity | Connexions | R1 | R2 | RL | RLsum |
37
+ |:------ |:------------ |:---------- |:-------- | :--------- |:----- |:----- |:----- |:----- |
38
+ | 4096 | Pooling | 32 | 4 | 160 | 42.75 | 16.34 | 25.20 | 38.23 |
39
+ | 4096 | Stride | 32 | 4 | 160 | 44.23 | 17.21 | 25.71 | 39.72 |
40
+ | 4096 | Block Stride | 32 | 4 | 160 | 44.15 | 17.10 | 25.68 | 39.60 |
41
+ | 4096 | Norm | 32 | 4 | 160 | 42.02 | 15.65 | 24.56 | 37.45 |
42
+ | 4096 | LSH | 32 | 4 | 160 | 42.58 | 16.21 | 25.10 | 38.04 |
43
 
44
  ## Model description
45
  The model relies on Local-Sparse-Global attention to handle long sequences:
 
68
  - total_train_batch_size: 32
69
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
70
  - lr_scheduler_type: linear
71
+ - lr_scheduler_warmup_ratio: 0.1
72
+ - num_epochs: 6.0
73
 
74
  ### Generate hyperparameters
75
 
 
77
  - dataset_name: scientific_papers
78
  - dataset_config_name: arxiv
79
  - eval_batch_size: 8
80
+ - eval_samples: 6440
81
  - early_stopping: True
82
  - ignore_pad_token_for_loss: True
83
  - length_penalty: 2.0
84
  - max_length: 320
85
+ - min_length: 32
86
  - num_beams: 5
 
87
  - no_repeat_ngram_size: None
88
  - seed: 123
89
 
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "ccdv/lsg-bart-base-4096-arxiv",
3
  "activation_dropout": 0.1,
4
  "activation_function": "gelu",
5
  "adaptive": true,
@@ -67,8 +67,8 @@
67
  "pool_with_global": true,
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
- "sparsity_factor": 4,
71
- "sparsity_type": "norm",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
 
1
  {
2
+ "_name_or_path": "/data/ccondevaux/lsg/text-summarization/tmp_final/arxiv/lsg_local",
3
  "activation_dropout": 0.1,
4
  "activation_function": "gelu",
5
  "adaptive": true,
 
67
  "pool_with_global": true,
68
  "scale_embedding": false,
69
  "sparse_block_size": 0,
70
+ "sparsity_factor": 2,
71
+ "sparsity_type": "none",
72
  "task_specific_params": {
73
  "summarization": {
74
  "length_penalty": 1.0,
modeling_lsg_bart.py CHANGED
@@ -41,8 +41,6 @@ class LSGBartConfig(BartConfig):
41
  ):
42
  """Constructs LSGConfig."""
43
  super().__init__(**kwargs)
44
-
45
- assert sparsity_type in ["norm", "lsh", "pooling", "stride"], "Sparsity mode must be 'norm', 'lsh' or 'pooling'"
46
 
47
  self.adaptive = adaptive
48
  self.auto_map = AUTO_MAP
@@ -55,7 +53,33 @@ class LSGBartConfig(BartConfig):
55
  self.sparse_block_size = sparse_block_size
56
  self.sparsity_factor = sparsity_factor
57
  self.sparsity_type = sparsity_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
 
 
 
 
59
 
60
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
61
  """
@@ -208,8 +232,6 @@ class LSGAttentionProduct(nn.Module):
208
  # Shape of blocks
209
  self.local_shapes = (self.block_size*3, self.block_size)
210
  if self.sparse_block_size and self.sparsity_factor > 0:
211
- assert self.block_size % self.sparsity_factor == 0, "block_size must be divisible by sparsity_factor"
212
- assert self.block_size//self.sparsity_factor >= 1, "Config is wrong, make sure block_size >= sparsity_factor"
213
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
214
 
215
  self.attention = BaseAttentionProduct(config)
@@ -393,21 +415,15 @@ class LSGBartEncoderAttention(BaseSelfAttention):
393
  }
394
 
395
  self.sparsity_type = config.sparsity_type
396
- self.get_sparse_elements = sparse_functions[self.sparsity_type]
397
-
398
- if config.sparsity_type == "stride":
399
- if config.sparsity_factor > config.encoder_attention_heads:
400
- logger.warning(
401
- "Warning: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
402
- )
403
 
404
  if config.sparsity_type == "lsh":
405
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
406
-
407
  def get_sparse_tokens_with_norm(self, keys, values, mask):
408
 
409
  if self.sparsity_factor == 1:
410
- return keys, values, mask
411
 
412
  with torch.no_grad():
413
 
@@ -435,7 +451,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
435
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
436
 
437
  if self.sparsity_factor == 1:
438
- return keys, values, mask
439
 
440
  keys = self.chunk(keys, self.sparsity_factor)
441
  values = self.chunk(values, self.sparsity_factor)
@@ -457,13 +473,30 @@ class LSGBartEncoderAttention(BaseSelfAttention):
457
  def get_sparse_tokens_with_stride(self, keys, values, mask):
458
 
459
  if self.sparsity_factor == 1:
460
- return keys, values, mask
461
 
462
  n, h, t, d = keys.size()
463
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
464
  sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
465
  sparse_idx = sparse_idx.expand(n, h, -1, 1)
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
468
  values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
469
  mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
@@ -473,7 +506,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
473
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
474
 
475
  if self.sparsity_factor == 1:
476
- return keys, values, mask
477
 
478
  block_size = min(self.block_size, self.sparse_block_size)
479
  keys = self.chunk(keys, block_size)
@@ -490,9 +523,9 @@ class LSGBartEncoderAttention(BaseSelfAttention):
490
  extra_factor = 1
491
 
492
  for _ in range(self.lsh_num_pre_rounds):
493
- keys, values, mask = self.lsg_round(keys, values, mask, t*extra_factor)
494
 
495
- keys, values, mask = self.lsg_round(keys, values, mask, t//self.sparsity_factor)
496
  keys /= mask + 1e-8
497
  values /= mask + 1e-8
498
 
@@ -500,7 +533,7 @@ class LSGBartEncoderAttention(BaseSelfAttention):
500
 
501
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
502
 
503
- def lsg_round(self, keys, values, mask, output_size):
504
 
505
  with torch.no_grad():
506
 
@@ -1304,6 +1337,7 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1304
  self.padding_idx = config.pad_token_id
1305
  self.max_target_positions = config.max_position_embeddings
1306
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
1307
 
1308
  if embed_tokens is not None:
1309
  self.embed_tokens = embed_tokens
@@ -1346,6 +1380,15 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1346
 
1347
  return combined_attention_mask
1348
 
 
 
 
 
 
 
 
 
 
1349
  def forward(
1350
  self,
1351
  input_ids=None,
@@ -1386,12 +1429,14 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1386
  if inputs_embeds is None:
1387
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1388
 
1389
- # Cut
1390
- if attention_mask is not None:
1391
- max_len = int(attention_mask.sum(dim=-1).max())
1392
- inputs_embeds = inputs_embeds[:, :max_len]
1393
- attention_mask = attention_mask[..., :max_len]
1394
- input_shape = inputs_embeds.size()[:-1]
 
 
1395
 
1396
  attention_mask = self._prepare_decoder_attention_mask(
1397
  attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1485,6 +1530,9 @@ class LSGBartDecoder(LSGBartPretrainedModel):
1485
  if encoder_hidden_states is not None:
1486
  all_cross_attentions += (layer_outputs[2],)
1487
 
 
 
 
1488
  # add hidden states from the last decoder layer
1489
  if output_hidden_states:
1490
  all_hidden_states += (hidden_states,)
@@ -1621,14 +1669,14 @@ class LSGBartModel(LSGBartPretrainedModel):
1621
  )
1622
 
1623
 
1624
- class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1625
 
1626
  base_model_prefix = "model"
1627
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1628
 
1629
  def __init__(self, config):
1630
 
1631
- super().__init__(config)
1632
  self.model = LSGBartModel(config)
1633
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1634
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
@@ -1636,157 +1684,12 @@ class LSGBartForConditionalGeneration(LSGBartPretrainedModel):
1636
  # Initialize weights and apply final processing
1637
  self.post_init()
1638
 
1639
- def get_encoder(self):
1640
- return self.model.get_encoder()
1641
-
1642
- def get_decoder(self):
1643
- return self.model.get_decoder()
1644
-
1645
- def resize_token_embeddings(self, new_num_tokens):
1646
- new_embeddings = super().resize_token_embeddings(new_num_tokens)
1647
- self._resize_final_logits_bias(new_num_tokens)
1648
- return new_embeddings
1649
-
1650
- def _resize_final_logits_bias(self, new_num_tokens):
1651
- old_num_tokens = self.final_logits_bias.shape[-1]
1652
- if new_num_tokens <= old_num_tokens:
1653
- new_bias = self.final_logits_bias[:, :new_num_tokens]
1654
- else:
1655
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1656
- new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1657
- self.register_buffer("final_logits_bias", new_bias)
1658
-
1659
- def get_output_embeddings(self):
1660
- return self.lm_head
1661
-
1662
- def set_output_embeddings(self, new_embeddings):
1663
- self.lm_head = new_embeddings
1664
-
1665
- def forward(
1666
- self,
1667
- input_ids=None,
1668
- attention_mask=None,
1669
- decoder_input_ids=None,
1670
- decoder_attention_mask=None,
1671
- head_mask=None,
1672
- decoder_head_mask=None,
1673
- cross_attn_head_mask=None,
1674
- encoder_outputs=None,
1675
- past_key_values=None,
1676
- inputs_embeds=None,
1677
- decoder_inputs_embeds=None,
1678
- labels=None,
1679
- use_cache=None,
1680
- output_attentions=None,
1681
- output_hidden_states=None,
1682
- return_dict=None,
1683
- ):
1684
-
1685
- r"""
1686
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1687
- Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
1688
- config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1689
- (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1690
- Returns:
1691
- """
1692
-
1693
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1694
-
1695
- if labels is not None:
1696
- if decoder_input_ids is None and decoder_inputs_embeds is None:
1697
- decoder_input_ids = shift_tokens_right(
1698
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
1699
- )
1700
-
1701
- outputs = self.model(
1702
- input_ids,
1703
- attention_mask=attention_mask,
1704
- decoder_input_ids=decoder_input_ids,
1705
- encoder_outputs=encoder_outputs,
1706
- decoder_attention_mask=decoder_attention_mask,
1707
- head_mask=head_mask,
1708
- decoder_head_mask=decoder_head_mask,
1709
- cross_attn_head_mask=cross_attn_head_mask,
1710
- past_key_values=past_key_values,
1711
- inputs_embeds=inputs_embeds,
1712
- decoder_inputs_embeds=decoder_inputs_embeds,
1713
- use_cache=use_cache,
1714
- output_attentions=output_attentions,
1715
- output_hidden_states=output_hidden_states,
1716
- return_dict=return_dict,
1717
- )
1718
-
1719
-
1720
- lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1721
-
1722
- masked_lm_loss = None
1723
- if labels is not None:
1724
- loss_fct = CrossEntropyLoss()
1725
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1726
-
1727
- if not return_dict:
1728
- output = (lm_logits,) + outputs[1:]
1729
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1730
-
1731
- return Seq2SeqLMOutput(
1732
- loss=masked_lm_loss,
1733
- logits=lm_logits,
1734
- past_key_values=outputs.past_key_values,
1735
- decoder_hidden_states=outputs.decoder_hidden_states,
1736
- decoder_attentions=outputs.decoder_attentions,
1737
- cross_attentions=outputs.cross_attentions,
1738
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1739
- encoder_hidden_states=outputs.encoder_hidden_states,
1740
- encoder_attentions=outputs.encoder_attentions,
1741
- )
1742
-
1743
- def prepare_inputs_for_generation(
1744
- self,
1745
- decoder_input_ids,
1746
- past=None,
1747
- attention_mask=None,
1748
- head_mask=None,
1749
- decoder_head_mask=None,
1750
- cross_attn_head_mask=None,
1751
- use_cache=None,
1752
- encoder_outputs=None,
1753
- **kwargs
1754
- ):
1755
- # cut decoder_input_ids if past is used
1756
- if past is not None:
1757
- decoder_input_ids = decoder_input_ids[:, -1:]
1758
-
1759
- return {
1760
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
1761
- "encoder_outputs": encoder_outputs,
1762
- "past_key_values": past,
1763
- "decoder_input_ids": decoder_input_ids,
1764
- "attention_mask": attention_mask,
1765
- "head_mask": head_mask,
1766
- "decoder_head_mask": decoder_head_mask,
1767
- "cross_attn_head_mask": cross_attn_head_mask,
1768
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1769
- }
1770
-
1771
- def prepare_decoder_input_ids_from_labels(self, labels):
1772
- return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1773
-
1774
- @staticmethod
1775
- def _reorder_cache(past, beam_idx):
1776
- reordered_past = ()
1777
- for layer_past in past:
1778
- # cached cross_attention states don't have to be reordered -> they are always the same
1779
- reordered_past += (
1780
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1781
- )
1782
- return reordered_past
1783
-
1784
 
1785
- class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1786
 
1787
- def __init__(self, config, **kwargs):
1788
 
1789
- super().__init__(config, **kwargs)
1790
  self.model = LSGBartModel(config)
1791
  self.classification_head = LSGBartClassificationHead(
1792
  config.d_model,
@@ -1797,115 +1700,12 @@ class LSGBartForSequenceClassification(LSGBartPretrainedModel):
1797
  self.model._init_weights(self.classification_head.dense)
1798
  self.model._init_weights(self.classification_head.out_proj)
1799
 
1800
- def forward(
1801
- self,
1802
- input_ids=None,
1803
- attention_mask=None,
1804
- decoder_input_ids=None,
1805
- decoder_attention_mask=None,
1806
- head_mask=None,
1807
- decoder_head_mask=None,
1808
- cross_attn_head_mask=None,
1809
- encoder_outputs=None,
1810
- inputs_embeds=None,
1811
- decoder_inputs_embeds=None,
1812
- labels=None,
1813
- use_cache=None,
1814
- output_attentions=None,
1815
- output_hidden_states=None,
1816
- return_dict=None,
1817
- ):
1818
-
1819
- r"""
1820
- labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1821
- Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1822
- config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1823
- """
1824
-
1825
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1826
- if labels is not None:
1827
- use_cache = False
1828
-
1829
- if input_ids is None and inputs_embeds is not None:
1830
- raise NotImplementedError(
1831
- f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1832
- )
1833
-
1834
- outputs = self.model(
1835
- input_ids,
1836
- attention_mask=attention_mask,
1837
- decoder_input_ids=decoder_input_ids,
1838
- decoder_attention_mask=decoder_attention_mask,
1839
- head_mask=head_mask,
1840
- decoder_head_mask=decoder_head_mask,
1841
- cross_attn_head_mask=cross_attn_head_mask,
1842
- encoder_outputs=encoder_outputs,
1843
- inputs_embeds=inputs_embeds,
1844
- decoder_inputs_embeds=decoder_inputs_embeds,
1845
- use_cache=use_cache,
1846
- output_attentions=output_attentions,
1847
- output_hidden_states=output_hidden_states,
1848
- return_dict=return_dict,
1849
- )
1850
- hidden_states = outputs[0] # last hidden state
1851
-
1852
- eos_mask = input_ids.eq(self.config.eos_token_id)
1853
-
1854
- t, t_ = eos_mask.size()[-1], hidden_states.size()[-2]
1855
- if t > t_:
1856
- eos_mask = eos_mask[:, :t_]
1857
-
1858
- if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1859
- raise ValueError("All examples must have the same number of <eos> tokens.")
1860
- sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1861
- :, -1, :
1862
- ]
1863
- logits = self.classification_head(sentence_representation)
1864
-
1865
- loss = None
1866
- if labels is not None:
1867
- if self.config.problem_type is None:
1868
- if self.config.num_labels == 1:
1869
- self.config.problem_type = "regression"
1870
- elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1871
- self.config.problem_type = "single_label_classification"
1872
- else:
1873
- self.config.problem_type = "multi_label_classification"
1874
-
1875
- if self.config.problem_type == "regression":
1876
- loss_fct = MSELoss()
1877
- if self.config.num_labels == 1:
1878
- loss = loss_fct(logits.squeeze(), labels.squeeze())
1879
- else:
1880
- loss = loss_fct(logits, labels)
1881
- elif self.config.problem_type == "single_label_classification":
1882
- loss_fct = CrossEntropyLoss()
1883
- loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1884
- elif self.config.problem_type == "multi_label_classification":
1885
- loss_fct = BCEWithLogitsLoss()
1886
- loss = loss_fct(logits, labels)
1887
- if not return_dict:
1888
- output = (logits,) + outputs[1:]
1889
- return ((loss,) + output) if loss is not None else output
1890
-
1891
- return Seq2SeqSequenceClassifierOutput(
1892
- loss=loss,
1893
- logits=logits,
1894
- past_key_values=outputs.past_key_values,
1895
- decoder_hidden_states=outputs.decoder_hidden_states,
1896
- decoder_attentions=outputs.decoder_attentions,
1897
- cross_attentions=outputs.cross_attentions,
1898
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1899
- encoder_hidden_states=outputs.encoder_hidden_states,
1900
- encoder_attentions=outputs.encoder_attentions,
1901
- )
1902
 
 
1903
 
1904
- class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1905
 
1906
- def __init__(self, config):
1907
-
1908
- super().__init__(config)
1909
 
1910
  config.num_labels = 2
1911
  self.num_labels = config.num_labels
@@ -1915,102 +1715,6 @@ class LSGBartForQuestionAnswering(LSGBartPretrainedModel):
1915
 
1916
  self.model._init_weights(self.qa_outputs)
1917
 
1918
- def forward(
1919
- self,
1920
- input_ids=None,
1921
- attention_mask=None,
1922
- decoder_input_ids=None,
1923
- decoder_attention_mask=None,
1924
- head_mask=None,
1925
- decoder_head_mask=None,
1926
- cross_attn_head_mask=None,
1927
- encoder_outputs=None,
1928
- start_positions=None,
1929
- end_positions=None,
1930
- inputs_embeds=None,
1931
- decoder_inputs_embeds=None,
1932
- use_cache=None,
1933
- output_attentions=None,
1934
- output_hidden_states=None,
1935
- return_dict=None,
1936
- ):
1937
-
1938
- r"""
1939
- start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1940
- Labels for position (index) of the start of the labelled span for computing the token classification loss.
1941
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1942
- are not taken into account for computing the loss.
1943
- end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1944
- Labels for position (index) of the end of the labelled span for computing the token classification loss.
1945
- Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1946
- are not taken into account for computing the loss.
1947
- """
1948
-
1949
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1950
- if start_positions is not None and end_positions is not None:
1951
- use_cache = False
1952
-
1953
- outputs = self.model(
1954
- input_ids,
1955
- attention_mask=attention_mask,
1956
- decoder_input_ids=decoder_input_ids,
1957
- decoder_attention_mask=decoder_attention_mask,
1958
- head_mask=head_mask,
1959
- decoder_head_mask=decoder_head_mask,
1960
- cross_attn_head_mask=cross_attn_head_mask,
1961
- encoder_outputs=encoder_outputs,
1962
- inputs_embeds=inputs_embeds,
1963
- decoder_inputs_embeds=decoder_inputs_embeds,
1964
- use_cache=use_cache,
1965
- output_attentions=output_attentions,
1966
- output_hidden_states=output_hidden_states,
1967
- return_dict=return_dict,
1968
- )
1969
-
1970
- sequence_output = outputs[0]
1971
-
1972
- logits = self.qa_outputs(sequence_output)
1973
- start_logits, end_logits = logits.split(1, dim=-1)
1974
- start_logits = start_logits.squeeze(-1).contiguous()
1975
- end_logits = end_logits.squeeze(-1).contiguous()
1976
-
1977
- total_loss = None
1978
- if start_positions is not None and end_positions is not None:
1979
- # If we are on multi-GPU, split add a dimension
1980
- if len(start_positions.size()) > 1:
1981
- start_positions = start_positions.squeeze(-1)
1982
- if len(end_positions.size()) > 1:
1983
- end_positions = end_positions.squeeze(-1)
1984
- # sometimes the start/end positions are outside our model inputs, we ignore these terms
1985
- ignored_index = start_logits.size(1)
1986
- start_positions = start_positions.clamp(0, ignored_index)
1987
- end_positions = end_positions.clamp(0, ignored_index)
1988
-
1989
- loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1990
- start_loss = loss_fct(start_logits, start_positions)
1991
- end_loss = loss_fct(end_logits, end_positions)
1992
- total_loss = (start_loss + end_loss) / 2
1993
-
1994
- if not return_dict:
1995
- output = (
1996
- start_logits,
1997
- end_logits,
1998
- ) + outputs[1:]
1999
- return ((total_loss,) + output) if total_loss is not None else output
2000
-
2001
- return Seq2SeqQuestionAnsweringModelOutput(
2002
- loss=total_loss,
2003
- start_logits=start_logits,
2004
- end_logits=end_logits,
2005
- past_key_values=outputs.past_key_values,
2006
- decoder_hidden_states=outputs.decoder_hidden_states,
2007
- decoder_attentions=outputs.decoder_attentions,
2008
- cross_attentions=outputs.cross_attentions,
2009
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
2010
- encoder_hidden_states=outputs.encoder_hidden_states,
2011
- encoder_attentions=outputs.encoder_attentions,
2012
- )
2013
-
2014
 
2015
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2016
  """
@@ -2018,7 +1722,7 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2018
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
2019
  """
2020
 
2021
- def __init__(self, config):
2022
  super().__init__(config)
2023
  self.decoder = LSGBartDecoder(config)
2024
 
@@ -2026,14 +1730,14 @@ class LSGBartDecoderWrapper(LSGBartPretrainedModel):
2026
  return self.decoder(*args, **kwargs)
2027
 
2028
 
2029
- class LSGBartForCausalLM(LSGBartPretrainedModel):
2030
 
2031
- def __init__(self, config):
2032
 
2033
- super().__init__(config)
2034
  config = copy.deepcopy(config)
2035
  config.is_decoder = True
2036
  config.is_encoder_decoder = False
 
2037
  self.model = LSGBartDecoderWrapper(config)
2038
 
2039
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -2041,105 +1745,6 @@ class LSGBartForCausalLM(LSGBartPretrainedModel):
2041
  # Initialize weights and apply final processing
2042
  self.post_init()
2043
 
2044
- def get_input_embeddings(self):
2045
- return self.model.decoder.embed_tokens
2046
-
2047
- def set_input_embeddings(self, value):
2048
- self.model.decoder.embed_tokens = value
2049
-
2050
- def get_output_embeddings(self):
2051
- return self.lm_head
2052
-
2053
- def set_output_embeddings(self, new_embeddings):
2054
- self.lm_head = new_embeddings
2055
-
2056
- def set_decoder(self, decoder):
2057
- self.model.decoder = decoder
2058
-
2059
- def get_decoder(self):
2060
- return self.model.decoder
2061
-
2062
- def forward(
2063
- self,
2064
- input_ids=None,
2065
- attention_mask=None,
2066
- encoder_hidden_states=None,
2067
- encoder_attention_mask=None,
2068
- head_mask=None,
2069
- cross_attn_head_mask=None,
2070
- past_key_values=None,
2071
- inputs_embeds=None,
2072
- labels=None,
2073
- use_cache=None,
2074
- output_attentions=None,
2075
- output_hidden_states=None,
2076
- return_dict=None,
2077
- ):
2078
-
2079
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
2080
- output_hidden_states = (
2081
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2082
- )
2083
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2084
-
2085
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2086
- outputs = self.model.decoder(
2087
- input_ids=input_ids,
2088
- attention_mask=attention_mask,
2089
- encoder_hidden_states=encoder_hidden_states,
2090
- encoder_attention_mask=encoder_attention_mask,
2091
- head_mask=head_mask,
2092
- cross_attn_head_mask=cross_attn_head_mask,
2093
- past_key_values=past_key_values,
2094
- inputs_embeds=inputs_embeds,
2095
- use_cache=use_cache,
2096
- output_attentions=output_attentions,
2097
- output_hidden_states=output_hidden_states,
2098
- return_dict=return_dict,
2099
- )
2100
-
2101
- logits = self.lm_head(outputs[0])
2102
-
2103
- loss = None
2104
- if labels is not None:
2105
- loss_fct = CrossEntropyLoss()
2106
- loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
2107
-
2108
- if not return_dict:
2109
- output = (logits,) + outputs[1:]
2110
- return (loss,) + output if loss is not None else output
2111
-
2112
- return CausalLMOutputWithCrossAttentions(
2113
- loss=loss,
2114
- logits=logits,
2115
- past_key_values=outputs.past_key_values,
2116
- hidden_states=outputs.hidden_states,
2117
- attentions=outputs.attentions,
2118
- cross_attentions=outputs.cross_attentions,
2119
- )
2120
-
2121
- def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
2122
- # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
2123
- if attention_mask is None:
2124
- attention_mask = input_ids.new_ones(input_ids.shape)
2125
-
2126
- if past:
2127
- input_ids = input_ids[:, -1:]
2128
- # first step, decoder_cached_states are empty
2129
- return {
2130
- "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
2131
- "attention_mask": attention_mask,
2132
- "past_key_values": past,
2133
- "use_cache": use_cache,
2134
- }
2135
-
2136
- @staticmethod
2137
- def _reorder_cache(past, beam_idx):
2138
- reordered_past = ()
2139
- for layer_past in past:
2140
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
2141
- return reordered_past
2142
-
2143
 
2144
  def str_to_class(classname):
2145
  return getattr(sys.modules[__name__], classname)
 
41
  ):
42
  """Constructs LSGConfig."""
43
  super().__init__(**kwargs)
 
 
44
 
45
  self.adaptive = adaptive
46
  self.auto_map = AUTO_MAP
 
53
  self.sparse_block_size = sparse_block_size
54
  self.sparsity_factor = sparsity_factor
55
  self.sparsity_type = sparsity_type
56
+
57
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
58
+ logger.warning(
59
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
60
+ self.sparsity_type = None
61
+
62
+ if self.sparsity_type == "stride":
63
+ if self.sparsity_factor > self.encoder_attention_heads:
64
+ logger.warning(
65
+ "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
66
+ )
67
+
68
+ if self.num_global_tokens < 1:
69
+ logger.warning(
70
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
71
+ )
72
+ self.num_global_tokens = 1
73
+ elif self.num_global_tokens > 512:
74
+ logger.warning(
75
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
76
+ )
77
+ self.num_global_tokens = 512
78
 
79
+ if self.sparsity_factor > 0:
80
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
81
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
82
+
83
 
84
  def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
85
  """
 
232
  # Shape of blocks
233
  self.local_shapes = (self.block_size*3, self.block_size)
234
  if self.sparse_block_size and self.sparsity_factor > 0:
 
 
235
  self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
236
 
237
  self.attention = BaseAttentionProduct(config)
 
415
  }
416
 
417
  self.sparsity_type = config.sparsity_type
418
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
 
 
 
 
 
 
419
 
420
  if config.sparsity_type == "lsh":
421
  self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
422
+
423
  def get_sparse_tokens_with_norm(self, keys, values, mask):
424
 
425
  if self.sparsity_factor == 1:
426
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
427
 
428
  with torch.no_grad():
429
 
 
451
  def get_sparse_tokens_with_pooling(self, keys, values, mask):
452
 
453
  if self.sparsity_factor == 1:
454
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
455
 
456
  keys = self.chunk(keys, self.sparsity_factor)
457
  values = self.chunk(values, self.sparsity_factor)
 
473
  def get_sparse_tokens_with_stride(self, keys, values, mask):
474
 
475
  if self.sparsity_factor == 1:
476
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
477
 
478
  n, h, t, d = keys.size()
479
  sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
480
  sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
481
  sparse_idx = sparse_idx.expand(n, h, -1, 1)
482
 
483
+ """
484
+ t, b = self.block_size, t // self.block_size
485
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
486
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1, 1)
487
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
488
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
489
+
490
+
491
+ t, b = self.block_size, t // self.block_size
492
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
493
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
494
+ sparse_idx = (sparse_idx % t)
495
+ #sparse_idx[..., -t//2:, :] = (sparse_idx[..., -t//2:, :] + t//2) % t
496
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
497
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
498
+ """
499
+
500
  keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
501
  values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
502
  mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
 
506
  def get_sparse_tokens_with_lsh(self, keys, values, mask):
507
 
508
  if self.sparsity_factor == 1:
509
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
510
 
511
  block_size = min(self.block_size, self.sparse_block_size)
512
  keys = self.chunk(keys, block_size)
 
523
  extra_factor = 1
524
 
525
  for _ in range(self.lsh_num_pre_rounds):
526
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
527
 
528
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
529
  keys /= mask + 1e-8
530
  values /= mask + 1e-8
531
 
 
533
 
534
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
535
 
536
+ def lsh_round(self, keys, values, mask, output_size):
537
 
538
  with torch.no_grad():
539
 
 
1337
  self.padding_idx = config.pad_token_id
1338
  self.max_target_positions = config.max_position_embeddings
1339
  self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1340
+ self.adaptive = config.adaptive
1341
 
1342
  if embed_tokens is not None:
1343
  self.embed_tokens = embed_tokens
 
1380
 
1381
  return combined_attention_mask
1382
 
1383
+ def resize_inputs(self, inputs_embeds, attention_mask):
1384
+ pad = 0
1385
+
1386
+ max_len = int(attention_mask.sum(dim=-1).max())
1387
+ pad = attention_mask.size()[-1] - max_len
1388
+ inputs_embeds = inputs_embeds[:, :max_len]
1389
+ attention_mask = attention_mask[..., :max_len]
1390
+ return pad, inputs_embeds, attention_mask
1391
+
1392
  def forward(
1393
  self,
1394
  input_ids=None,
 
1429
  if inputs_embeds is None:
1430
  inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1431
 
1432
+ # Resize to reduce computation
1433
+ pad = 0
1434
+ if self.adaptive:
1435
+ if attention_mask is not None:
1436
+ pad, inputs_embeds, attention_mask = self.resize_inputs(inputs_embeds, attention_mask)
1437
+ input_shape = inputs_embeds.size()[:-1]
1438
+ if encoder_attention_mask is not None:
1439
+ _, encoder_hidden_states, encoder_attention_mask = self.resize_inputs(encoder_hidden_states, encoder_attention_mask)
1440
 
1441
  attention_mask = self._prepare_decoder_attention_mask(
1442
  attention_mask, input_shape, inputs_embeds, past_key_values_length
 
1530
  if encoder_hidden_states is not None:
1531
  all_cross_attentions += (layer_outputs[2],)
1532
 
1533
+ # Resize to original shape
1534
+ hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), pad=(0, pad), value=0).transpose(-1, -2)
1535
+
1536
  # add hidden states from the last decoder layer
1537
  if output_hidden_states:
1538
  all_hidden_states += (hidden_states,)
 
1669
  )
1670
 
1671
 
1672
+ class LSGBartForConditionalGeneration(BartForConditionalGeneration, LSGBartPretrainedModel):
1673
 
1674
  base_model_prefix = "model"
1675
  _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1676
 
1677
  def __init__(self, config):
1678
 
1679
+ LSGBartPretrainedModel.__init__(self, config)
1680
  self.model = LSGBartModel(config)
1681
  self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1682
  self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
1684
  # Initialize weights and apply final processing
1685
  self.post_init()
1686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1687
 
1688
+ class LSGBartForSequenceClassification(BartForSequenceClassification, LSGBartPretrainedModel):
1689
 
1690
+ def __init__(self, config: LSGBartConfig, **kwargs):
1691
 
1692
+ LSGBartPretrainedModel.__init__(self, config, **kwargs)
1693
  self.model = LSGBartModel(config)
1694
  self.classification_head = LSGBartClassificationHead(
1695
  config.d_model,
 
1700
  self.model._init_weights(self.classification_head.dense)
1701
  self.model._init_weights(self.classification_head.out_proj)
1702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1703
 
1704
+ class LSGBartForQuestionAnswering(BartForQuestionAnswering, LSGBartPretrainedModel):
1705
 
1706
+ def __init__(self, config: LSGBartConfig):
1707
 
1708
+ LSGBartPretrainedModel.__init__(self, config)
 
 
1709
 
1710
  config.num_labels = 2
1711
  self.num_labels = config.num_labels
 
1715
 
1716
  self.model._init_weights(self.qa_outputs)
1717
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1718
 
1719
  class LSGBartDecoderWrapper(LSGBartPretrainedModel):
1720
  """
 
1722
  used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
1723
  """
1724
 
1725
+ def __init__(self, config: LSGBartConfig):
1726
  super().__init__(config)
1727
  self.decoder = LSGBartDecoder(config)
1728
 
 
1730
  return self.decoder(*args, **kwargs)
1731
 
1732
 
1733
+ class LSGBartForCausalLM(BartForCausalLM, LSGBartPretrainedModel):
1734
 
1735
+ def __init__(self, config: LSGBartConfig):
1736
 
 
1737
  config = copy.deepcopy(config)
1738
  config.is_decoder = True
1739
  config.is_encoder_decoder = False
1740
+ LSGBartPretrainedModel.__init__(self, config)
1741
  self.model = LSGBartDecoderWrapper(config)
1742
 
1743
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
1745
  # Initialize weights and apply final processing
1746
  self.post_init()
1747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1748
 
1749
  def str_to_class(classname):
1750
  return getattr(sys.modules[__name__], classname)
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:12252a53bd8fdbb6daafd2118a44789c0d3d37f01c62ccde0d10a92142e44a72
3
  size 578416695
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88af6fadc19698eaa5d49e63aa969487846fbdfb41852afe199350a98d04801d
3
  size 578416695