ccdv committed on
Commit
da22152
1 Parent(s): fba3c44
README.md ADDED
@@ -0,0 +1,115 @@
+ ---
+ language:
+ - en
+ tags:
+ - summarization
+ datasets:
+ - multi_news
+ metrics:
+ - rouge
+ model-index:
+ - name: ccdv/lsg-bart-base-4096-multinews
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ **This model relies on a custom modeling file; you need to add trust_remote_code=True**\
+ **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-bart-base-4096-multinews", trust_remote_code=True)
+ model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-bart-base-4096-multinews", trust_remote_code=True)
+
+ text = "Replace by what you want."
+ pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=0)
+ generated_text = pipe(
+     text,
+     truncation=True,
+     max_length=64,
+     no_repeat_ngram_size=7,
+     num_beams=2,
+     early_stopping=True
+ )
+ ```
+
+ # ccdv/lsg-bart-base-4096-multinews
+
+ This model is a fine-tuned version of [ccdv/lsg-bart-base-4096](https://huggingface.co/ccdv/lsg-bart-base-4096) on the multi_news dataset (default config). \
+ It achieves the following results on the test set:
+
+ | Length | Sparse Type  | Block Size | Sparsity | Connections | R1    | R2    | RL    | RLsum |
+ |:------ |:------------ |:---------- |:-------- |:----------- |:----- |:----- |:----- |:----- |
+ | 4096   | Local        | 256        | 0        | 768         | 47.10 | 18.94 | 25.22 | 43.13 |
+ | 4096   | Local        | 128        | 0        | 384         | 46.73 | 18.79 | 25.13 | 42.76 |
+ | 4096   | Pooling      | 128        | 4        | 644         | 46.83 | 18.87 | 25.23 | 42.86 |
+ | 4096   | Stride       | 128        | 4        | 644         | 46.83 | 18.68 | 24.98 | 42.88 |
+ | 4096   | Block Stride | 128        | 4        | 644         | 46.83 | 18.72 | 25.06 | 42.88 |
+ | 4096   | Norm         | 128        | 4        | 644         | 46.74 | 18.60 | 24.93 | 42.79 |
+ | 4096   | LSH          | 128        | 4        | 644         | 46.74 | 18.82 | 25.19 | 42.77 |
+
+ With a smaller block size (lower resources):
+
+ | Length | Sparse Type  | Block Size | Sparsity | Connections | R1    | R2    | RL    | RLsum |
+ |:------ |:------------ |:---------- |:-------- |:----------- |:----- |:----- |:----- |:----- |
+ | 4096   | Pooling      | 32         | 4        | 160         | 44.77 | 17.31 | 24.16 | 40.86 |
+ | 4096   | Stride       | 32         | 4        | 160         | 45.29 | 17.81 | 24.45 | 41.40 |
+ | 4096   | Block Stride | 32         | 4        | 160         | 45.39 | 17.86 | 24.51 | 41.43 |
+ | 4096   | Norm         | 32         | 4        | 160         | 44.65 | 17.25 | 24.09 | 40.76 |
+ | 4096   | LSH          | 32         | 4        | 160         | 44.44 | 17.20 | 24.00 | 40.57 |
+
+ ## Model description
+ The model relies on Local-Sparse-Global attention to handle long sequences:
+ ![attn](attn.png)
+
+ The model has about 145 million parameters (6 encoder layers - 6 decoder layers). \
+ The model is warm-started from BART-base, converted to handle long sequences (encoder only), and fine-tuned.
+
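+ The LSG attention parameters can also be overridden when loading the model. The snippet below is an illustrative sketch only: the values shown are examples, not the settings used to produce the results above.
+
+ ```python
+ from transformers import AutoModelForSeq2SeqLM
+
+ # Example values only; the defaults stored in config.json are
+ # block_size=256, sparse_block_size=0, sparsity_type="none".
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+     "ccdv/lsg-bart-base-4096-multinews",
+     trust_remote_code=True,
+     block_size=128,
+     sparse_block_size=128,
+     sparsity_factor=4,
+     sparsity_type="pooling",
+ )
+ ```
+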
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training (a `Seq2SeqTrainingArguments` sketch follows the list):
+ - learning_rate: 8e-05
+ - train_batch_size: 8
+ - seed: 42
+ - gradient_accumulation_steps: 4
+ - total_train_batch_size: 32
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_ratio: 0.1
+ - num_epochs: 12.0
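+
+ As a rough guide, these settings map onto `Seq2SeqTrainingArguments` as sketched below; `output_dir` and any option not listed above are assumptions, not part of the original setup.
+
+ ```python
+ from transformers import Seq2SeqTrainingArguments
+
+ # Sketch only: mirrors the hyperparameters listed above.
+ # The Adam betas/epsilon listed above are the TrainingArguments defaults.
+ training_args = Seq2SeqTrainingArguments(
+     output_dir="lsg-bart-base-4096-multinews",  # assumption
+     learning_rate=8e-5,
+     per_device_train_batch_size=8,
+     gradient_accumulation_steps=4,  # 8 x 4 = total train batch size 32
+     num_train_epochs=12.0,
+     lr_scheduler_type="linear",
+     warmup_ratio=0.1,
+     seed=42,
+ )
+ ```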
+
+ ### Generate hyperparameters
+
+ The following hyperparameters were used during generation (see the `generate()` sketch after this list):
+ - dataset_name: multi_news
+ - dataset_config_name: default
+ - eval_batch_size: 8
+ - eval_samples: 5622
+ - early_stopping: True
+ - ignore_pad_token_for_loss: True
+ - length_penalty: 2.0
+ - max_length: 320
+ - min_length: 32
+ - num_beams: 5
+ - no_repeat_ngram_size: None
+ - seed: 123
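+
+ Expressed as a `generate()` call, the evaluation settings above correspond roughly to the sketch below, reusing the `tokenizer` and `model` loaded at the top of this card; `long_document` and the 4096-token input truncation are placeholders/assumptions.
+
+ ```python
+ # Sketch only: reproduces the generation settings listed above.
+ inputs = tokenizer(long_document, truncation=True, max_length=4096, return_tensors="pt")
+ summary_ids = model.generate(
+     **inputs,
+     max_length=320,
+     min_length=32,
+     num_beams=5,
+     length_penalty=2.0,
+     early_stopping=True,
+ )
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+ ```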
+
+ ### Framework versions
+
+ - Transformers 4.18.0
+ - Pytorch 1.10.1+cu102
+ - Datasets 2.1.0
+ - Tokenizers 0.11.6
config.json ADDED
@@ -0,0 +1,96 @@
+ {
+   "_name_or_path": "models/ccdv/lsg-bart-base-4096-multinews",
+   "activation_dropout": 0.1,
+   "activation_function": "gelu",
+   "adaptive": true,
+   "add_bias_logits": false,
+   "add_final_layer_norm": false,
+   "architectures": [
+     "LSGBartForConditionalGeneration"
+   ],
+   "attention_dropout": 0.1,
+   "auto_map": {
+     "AutoConfig": "modeling_lsg_bart.LSGBartConfig",
+     "AutoModel": "modeling_lsg_bart.LSGBartModel",
+     "AutoModelForCausalLM": "modeling_lsg_bart.LSGBartForCausalLM",
+     "AutoModelForQuestionAnswering": "modeling_lsg_bart.LSGBartForQuestionAnswering",
+     "AutoModelForSeq2SeqLM": "modeling_lsg_bart.LSGBartForConditionalGeneration",
+     "AutoModelForSequenceClassification": "modeling_lsg_bart.LSGBartForSequenceClassification"
+   },
+   "base_model_prefix": "lsg",
+   "block_size": 256,
+   "bos_token_id": 0,
+   "classif_dropout": 0.1,
+   "classifier_dropout": 0.0,
+   "d_model": 768,
+   "decoder_attention_heads": 12,
+   "decoder_ffn_dim": 3072,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 6,
+   "decoder_start_token_id": 2,
+   "dropout": 0.1,
+   "early_stopping": true,
+   "encoder_attention_heads": 12,
+   "encoder_ffn_dim": 3072,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 6,
+   "eos_token_id": 2,
+   "forced_bos_token_id": 0,
+   "forced_eos_token_id": 2,
+   "gradient_checkpointing": false,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "length_penalty": 2.0,
+   "lsh_num_pre_rounds": 1,
+   "max_length": 320,
+   "max_position_embeddings": 4096,
+   "min_length": 32,
+   "model_type": "bart",
+   "no_repeat_ngram_size": null,
+   "normalize_before": false,
+   "normalize_embedding": true,
+   "num_beams": 5,
+   "num_global_tokens": 1,
+   "num_hidden_layers": 6,
+   "pad_token_id": 1,
+   "pass_global_tokens_to_decoder": true,
+   "pool_with_global": true,
+   "scale_embedding": false,
+   "sparse_block_size": 0,
+   "sparsity_factor": 2,
+   "sparsity_type": "none",
+   "task_specific_params": {
+     "summarization": {
+       "length_penalty": 1.0,
+       "max_length": 128,
+       "min_length": 12,
+       "num_beams": 4
+     },
+     "summarization_cnn": {
+       "length_penalty": 2.0,
+       "max_length": 142,
+       "min_length": 56,
+       "num_beams": 4
+     },
+     "summarization_xsum": {
+       "length_penalty": 1.0,
+       "max_length": 62,
+       "min_length": 11,
+       "num_beams": 6
+     }
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.18.0",
+   "use_cache": true,
+   "vocab_size": 50265
+ }
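
For reference, the LSG-specific fields above (`block_size`, `sparse_block_size`, `sparsity_type`, `num_global_tokens`, ...) can be inspected once the config is loaded. A minimal sketch, assuming `transformers` is installed:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("ccdv/lsg-bart-base-4096-multinews", trust_remote_code=True)
# Values from the file above: 256, 0, "none", 1
print(config.block_size, config.sparse_block_size, config.sparsity_type, config.num_global_tokens)
```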
merges.txt ADDED
The diff for this file is too large to render. See raw diff
modeling_lsg_bart.py ADDED
@@ -0,0 +1,1759 @@
+ from logging import warn
+ import torch
+ from transformers.models.bart.modeling_bart import *
+ from transformers.models.bart.modeling_bart import _expand_mask
+ import torch.nn as nn
+ from torch.nn import BCEWithLogitsLoss
+ import sys
+
+ AUTO_MAP = {
+     "AutoModel": "modeling_lsg_bart.LSGBartModel",
+     "AutoModelForCausalLM": "modeling_lsg_bart.LSGBartForCausalLM",
+     "AutoModelForQuestionAnswering": "modeling_lsg_bart.LSGBartForQuestionAnswering",
+     "AutoModelForSequenceClassification": "modeling_lsg_bart.LSGBartForSequenceClassification",
+     "AutoModelForSeq2SeqLM": "modeling_lsg_bart.LSGBartForConditionalGeneration"
+ }
+
+ class LSGBartConfig(BartConfig):
+     """
+     This class overrides :class:`~transformers.BartConfig`. Please check the superclass for the appropriate
+     documentation alongside usage examples.
+     """
+
+     base_model_prefix = "lsg"
+     model_type = "bart"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+     def __init__(
+         self,
+         adaptive=True,
+         base_model_prefix="lsg",
+         block_size=128,
+         lsh_num_pre_rounds=1,
+         num_global_tokens=1,
+         pass_global_tokens_to_decoder=True,
+         pool_with_global=True,
+         sparse_block_size=128,
+         sparsity_factor=2,
+         sparsity_type="norm",
+         **kwargs
+     ):
+         """Constructs LSGConfig."""
+         super().__init__(**kwargs)
+
+         self.adaptive = adaptive
+         self.auto_map = AUTO_MAP
+         self.base_model_prefix = base_model_prefix
+         self.block_size = block_size
+         self.lsh_num_pre_rounds = lsh_num_pre_rounds
+         self.num_global_tokens = num_global_tokens
+         self.pass_global_tokens_to_decoder = pass_global_tokens_to_decoder
+         self.pool_with_global = pool_with_global
+         self.sparse_block_size = sparse_block_size
+         self.sparsity_factor = sparsity_factor
+         self.sparsity_type = sparsity_type
+
+         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride"]:
+             logger.warning(
+                 "[WARNING CONFIG]: sparsity_type not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride'], setting sparsity_type=None, computation will skip sparse attention")
+             self.sparsity_type = None
+
+         if self.sparsity_type == "stride":
+             if self.sparsity_factor > self.encoder_attention_heads:
+                 logger.warning(
+                     "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride sparsity"
+                 )
+
+         if self.num_global_tokens < 1:
+             logger.warning(
+                 "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
+             )
+             self.num_global_tokens = 1
+         elif self.num_global_tokens > 512:
+             logger.warning(
+                 "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+             )
+             self.num_global_tokens = 512
+
+         if self.sparsity_factor > 0:
+             assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
+             assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
+
+
+ def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
+     """
+     Shift input ids one token to the right.
+     """
+     shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+     shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+     shifted_input_ids[:, 0] = decoder_start_token_id
+
+     if pad_token_id is None:
+         raise ValueError("self.model.config.pad_token_id has to be defined.")
+     # replace possible -100 values in labels by `pad_token_id`
+     shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+     return shifted_input_ids
+
+
+ def _make_causal_mask(input_ids_shape, dtype, past_key_values_length=0):
+     """
+     Make causal mask used for uni-directional (decoder) self-attention.
+     """
+     bsz, tgt_len = input_ids_shape
+     mask = torch.full((tgt_len, tgt_len), float("-inf"))
+     mask_cond = torch.arange(mask.size(-1))
+     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+     mask = mask.to(dtype)
+
+     if past_key_values_length > 0:
+         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
+     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+ def _expand_mask(mask, dtype, tgt_len=None):
+     """
+     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+     """
+     bsz, src_len = mask.size()
+     tgt_len = tgt_len if tgt_len is not None else src_len
+
+     expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+     inverted_mask = 1.0 - expanded_mask
+
+     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
+
+
+ class BaseSelfAttention(nn.Module):
+
+     def __init__(
+         self,
+         embed_dim,
+         num_heads,
+         dropout=0.0,
+         is_decoder=False,
+         bias=True,
+     ):
+
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.dropout = dropout
+         self.head_dim = embed_dim // num_heads
+
+         if (self.head_dim * num_heads) != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                 f" and `num_heads`: {num_heads})."
+             )
+         self.scaling = self.head_dim ** -0.5
+         self.is_decoder = is_decoder
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+     def transpose_for_scores(self, x):
+         new_x_shape = x.size()[:-1] + (
+             self.num_heads,
+             self.head_dim,
+         )
+         x = x.view(*new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def reshape_output(self, context_layer):
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
+         return context_layer.view(*new_context_layer_shape)
+
+     def project_QKV(self, hidden_states):
+
+         query_layer = self.transpose_for_scores(self.q_proj(hidden_states))
+         key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
+         value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
+         return query_layer, key_layer, value_layer
+
+
+ class BaseAttentionProduct(nn.Module):
+
+     def __init__(self, config):
+         """
+         Compute attention: softmax(Q @ K.T) @ V
+         """
+         super().__init__()
+         self.dropout = nn.Dropout(config.attention_dropout)
+
+     def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
+
+         d = query_layer.shape[-1]
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
+
+         del query_layer
+         del key_layer
+
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the forward() function)
+             attention_scores = attention_scores + attention_mask
+             del attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         context_layer = self.dropout(attention_probs) @ value_layer
+
+         return context_layer
+
+
+ class LSGAttentionProduct(nn.Module):
+
+     def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4):
+         """
+         Compute block or overlapping blocks attention products
+         """
+         super().__init__()
+
+         self.block_size = block_size
+         self.sparse_block_size = sparse_block_size
+         self.sparsity_factor = sparsity_factor
+
+         if self.block_size is None:
+             self.block_size = config.block_size
+
+         if self.sparse_block_size is None:
+             self.sparse_block_size = config.sparse_block_size
+
+         # Shape of blocks
+         self.local_shapes = (self.block_size*3, self.block_size)
+         if self.sparse_block_size and self.sparsity_factor > 0:
+             self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
+
+         self.attention = BaseAttentionProduct(config)
+
+     def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
+
+         # Build local tokens
+         local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
+         del hidden_states
+
+         # Build sparse tokens
+         if sparse_hidden_states is not None:
+             sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
+
+         return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
+
+     def forward(
+         self,
+         query_layer,
+         key_layer,
+         value_layer,
+         attention_mask=None,
+         sparse_key=None,
+         sparse_value=None,
+         sparse_mask=None,
+         global_key=None,
+         global_value=None,
+         global_mask=None
+     ):
+
+         # Input batch, heads, length, hidden_size
+         n, h, t, d = query_layer.size()
+         n_blocks = t // self.block_size
+         assert t % self.block_size == 0
+
+         key_layer = self.build_lsg_inputs(
+             key_layer,
+             sparse_key,
+             global_key
+         )
+         del sparse_key
+         del global_key
+
+         value_layer = self.build_lsg_inputs(
+             value_layer,
+             sparse_value,
+             global_value
+         )
+         del sparse_value
+         del global_value
+
+         attention_mask = self.build_lsg_inputs(
+             attention_mask,
+             sparse_mask,
+             global_mask.transpose(-1, -2),
+             is_attn_mask=True
+         ).transpose(-1, -2)
+         del sparse_mask
+         del global_mask
+
+         # expect (..., t, d) shape
+         # Compute attention
+         context_layer = self.attention(
+             query_layer=self.chunk(query_layer, n_blocks),
+             key_layer=key_layer,
+             value_layer=value_layer,
+             attention_mask=attention_mask
+         )
+
+         return context_layer.reshape(n, h, -1, d)
+
+     def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
+
+         size, step = self.local_shapes
+         s = (size - step) // 2
+
+         # Pad before block reshaping
+         if is_attn_mask:
+             pad_value = -10000
+             hidden_states = hidden_states.transpose(-1, -2)
+         else:
+             pad_value = 0
+
+         hidden_states = torch.nn.functional.pad(
+             hidden_states.transpose(-1, -2),
+             pad=(s, s),
+             value=pad_value
+         ).transpose(-1, -2)
+
+         # Make blocks
+         hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
+
+         return hidden_states
+
+     def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
+
+         size, step = self.sparse_shapes
+
+         # Handle the case of an odd step
+         odd_offset = (step % 2)
+
+         # n, h, t, d*2 + 1
+         size = size*2
+         s = (size - step) // 2 + odd_offset
+
+         # Pad before block reshaping
+         if is_attn_mask:
+             pad_value = -10000
+             hidden_states = hidden_states.transpose(-1, -2)
+         else:
+             pad_value = 0
+
+         hidden_states = torch.nn.functional.pad(
+             hidden_states.transpose(-1, -2),
+             pad=(s, s),
+             value=pad_value
+         ).transpose(-1, -2)
+
+         # Make blocks
+         hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
+
+         # Fix case where block_size == sparsity_factor
+         if odd_offset:
+             hidden_states = hidden_states[..., :-1, :, :]
+
+         # Indexes for selection
+         u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
+         s = self.sparse_block_size
+
+         u_ = u + odd_offset
+         return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
+
+     def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
+
+         n, h, b, t, d = x_local.size()
+         x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
+         if x_sparse is not None:
+             return torch.cat([x_global, x_sparse, x_local], dim=dim)
+         return torch.cat([x_global, x_local], dim=dim)
+
+     def chunk(self, x, n_blocks):
+
+         t, d = x.size()[-2:]
+         return x.reshape(*x.size()[:-2], n_blocks, -1, d)
+
+
+ class LSGBartEncoderAttention(BaseSelfAttention):
+     '''
+     Compute local attention with overlapping blocks
+     Use global attention for tokens with highest norm
+     '''
+     def __init__(
+         self,
+         config,
+         embed_dim,
+         num_heads,
+         dropout
+     ):
+
+         super().__init__(embed_dim, num_heads, dropout)
+
+         self.block_size = config.block_size
+         self.sparse_block_size = config.sparse_block_size
+         self.num_global_tokens = config.num_global_tokens
+         self.sparsity_factor = config.sparsity_factor
+
+         self.attention = LSGAttentionProduct(
+             config,
+             block_size=config.block_size,
+             sparse_block_size=config.sparse_block_size,
+             sparsity_factor=self.sparsity_factor,
+         )
+
+         self.full_attention = BaseAttentionProduct(config)
+
+         sparse_functions = {
+             "norm": self.get_sparse_tokens_with_norm,
+             "pooling": self.get_sparse_tokens_with_pooling,
+             "lsh": self.get_sparse_tokens_with_lsh,
+             "stride": self.get_sparse_tokens_with_stride,
+         }
+
+         self.sparsity_type = config.sparsity_type
+         self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+
+         if config.sparsity_type == "lsh":
+             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
+
+     def get_sparse_tokens_with_norm(self, keys, values, mask):
+
+         if self.sparsity_factor == 1:
+             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+         with torch.no_grad():
+
+             block_size = min(self.block_size, self.sparse_block_size)
+             key_norm = keys.detach().norm(dim=-1, keepdim=True)
+             key_norm = key_norm * ~mask.transpose(-1, -2).bool()
+             key_norm = self.chunk(key_norm, block_size)
+
+             n, h, b, t, d = key_norm.size()
+
+             idx = key_norm.argsort(dim=-2)
+             del key_norm
+             idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
+
+             split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
+             sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
+
+         d = keys.size()[-1]
+         keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
+         values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
+         mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
+
+         return keys, values, mask
+
+     def get_sparse_tokens_with_pooling(self, keys, values, mask):
+
+         if self.sparsity_factor == 1:
+             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+         keys = self.chunk(keys, self.sparsity_factor)
+         values = self.chunk(values, self.sparsity_factor)
+
+         n, h, b, t, d = keys.size()
+         mask = mask.reshape(n, 1, b, 1, t)
+         mask = ~mask.transpose(-1, -2).bool()
+
+         keys = keys * mask
+         values = values * mask
+
+         mask = mask.sum(dim=-2)
+         keys = keys.sum(dim=-2) / (mask + 1e-6)
+         values = values.sum(dim=-2) / (mask + 1e-6)
+
+         mask = - (1. - mask.clamp(0, 1)) * 1e4
+         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
+     def get_sparse_tokens_with_stride(self, keys, values, mask):
+
+         if self.sparsity_factor == 1:
+             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+         n, h, t, d = keys.size()