ccdv committed
Commit 17833c9
1 Parent(s): 5f68e69
.gitattributes CHANGED
@@ -20,6 +20,7 @@
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,150 @@
1
+ ---
2
+ language: en
3
+ tags:
4
+ - long context
5
+ pipeline_tag: fill-mask
6
+ ---
7
+
8
+ # LSG model
9
+ **Transformers >= 4.18.0**\
10
+ **This model relies on a custom modeling file; you need to add trust_remote_code=True**\
11
+ **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
12
+
13
+ * [Usage](#usage)
14
+ * [Parameters](#parameters)
15
+ * [Sparse selection type](#sparse-selection-type)
16
+ * [Tasks](#tasks)
17
+ * [Training global tokens](#training-global-tokens)
18
+
19
+ This model is a small version of the [XLM-roberta-base](https://huggingface.co/xlm-roberta-base) model without additional pretraining yet. It uses the same number of parameters/layers and the same tokenizer.
20
+
21
+
22
+ This model can handle long sequences, and it is faster and more efficient than Longformer or BigBird (from Transformers). It relies on Local + Sparse + Global attention (LSG).
23
+
24
+
25
+ The model requires sequences whose length is a multiple of the block size. The model is "adaptive" and automatically pads the sequences if needed (adaptive=True in the config). It is nevertheless recommended, thanks to the tokenizer, to truncate the inputs (truncation=True) and optionally to pad with a multiple of the block size (pad_to_multiple_of=...), as in the sketch below. \
26
+
27
+
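+ A minimal tokenization sketch following this recommendation (the input text is illustrative; 256 is the block_size of this checkpoint, see config.json):
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")
+
+ # Truncate to the model maximum length and pad up to a multiple of the block size.
+ # Padding must be enabled for pad_to_multiple_of to take effect.
+ inputs = tokenizer(
+     "This is a test for long inputs. " * 300,
+     truncation=True,
+     padding=True,
+     pad_to_multiple_of=256,
+     return_tensors="pt"
+ )
+ ```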
28
+ Encoder-decoder setups are supported, but I didn't test them extensively.\
29
+ Implemented in PyTorch.
30
+
31
+ ![attn](attn.png)
32
+
33
+ ## Usage
34
+ The model relies on a custom modeling file, so you need to add trust_remote_code=True to use it.
35
+
36
+ ```python
37
+ from transformers import AutoModel, AutoTokenizer
38
+
39
+ model = AutoModel.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", trust_remote_code=True)
40
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")
41
+ ```
42
+
43
+ ## Parameters
44
+ You can change various parameters, such as:
45
+ * the number of global tokens (num_global_tokens=1)
46
+ * local block size (block_size=128)
47
+ * sparse block size (sparse_block_size=128)
48
+ * sparsity factor (sparsity_factor=2)
49
+ * mask_first_token (mask the first token since it is redundant with the first global token)
50
+ * see config.json file
51
+
52
+ Default parameters work well in practice. If you are short on memory, reduce the block sizes, increase the sparsity factor, and remove the dropout in the attention score matrix.
53
+
54
+ ```python
55
+ from transformers import AutoModel
56
+
57
+ model = AutoModel.from_pretrained("ccdv/lsg-xlm-roberta-base-4096",
58
+ trust_remote_code=True,
59
+ num_global_tokens=16,
60
+ block_size=64,
61
+ sparse_block_size=64,
62
+ attention_probs_dropout_prob=0.0,
63
+ sparsity_factor=4,
64
+ sparsity_type="none",
65
+ mask_first_token=True
66
+ )
67
+ ```
68
+
69
+ ## Sparse selection type
70
+
71
+ There are 5 different sparse selection patterns. The best type is task dependent. \
72
+ Note that for sequences with length < 2*block_size, the type has no effect. A loading sketch using one of these patterns is shown after the list.
73
+
74
+ * sparsity_type="norm", select highest norm tokens
75
+ * Works best for a small sparsity_factor (2 to 4)
76
+ * Additional parameters:
77
+ * None
78
+ * sparsity_type="pooling", use average pooling to merge tokens
79
+ * Works best for a small sparsity_factor (2 to 4)
80
+ * Additional parameters:
81
+ * None
82
+ * sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
83
+ * Works best for a large sparsity_factor (4+)
84
+ * LSH relies on random projections, thus inference may differ slightly with different seeds
85
+ * Additional parameters:
86
+ * lsg_num_pre_rounds=1, pre-merge tokens n times before computing centroids
87
+ * sparsity_type="stride", use a striding mecanism per head
88
+ * Each head will use different tokens strided by sparsity_factor
89
+ * Not recommended if sparsity_factor > num_heads
90
+ * sparsity_type="block_stride", use a striding mecanism per head
91
+ * Each head will use blocks of tokens strided by sparsity_factor
92
+ * Not recommended if sparsity_factor > num_heads
93
+
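+ For example, a loading sketch using the LSH pattern (the values below are illustrative; this checkpoint ships with sparse_block_size=0, which disables sparse attention, so sparse_block_size must be set as well):
+
+ ```python
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained("ccdv/lsg-xlm-roberta-base-4096",
+     trust_remote_code=True,
+     sparsity_type="lsh",     # cluster similar tokens with LSH
+     sparse_block_size=256,   # > 0 to enable sparse attention
+     sparsity_factor=8,       # LSH works best with a larger sparsity factor
+     lsg_num_pre_rounds=1     # pre-merge tokens once before computing centroids
+ )
+ ```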
94
+ ## Tasks
95
+ Fill mask example:
96
+ ```python
97
+ from transformers import FillMaskPipeline, AutoModelForMaskedLM, AutoTokenizer
98
+
99
+ model = AutoModelForMaskedLM.from_pretrained("ccdv/lsg-xlm-roberta-base-4096", trust_remote_code=True)
100
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")
101
+
102
+ SENTENCES = ["Paris is the <mask> of France."]
103
+ pipeline = FillMaskPipeline(model, tokenizer)
104
+ output = pipeline(SENTENCES, top_k=1)
105
+
106
+ output = [o[0]["sequence"] for o in output]
107
+ > ['Paris is the capital of France.']
108
+ ```
109
+
110
+
111
+ Classification example:
112
+ ```python
113
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
114
+
115
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-xlm-roberta-base-4096",
116
+ trust_remote_code=True,
117
+ pool_with_global=True, # pool with a global token instead of first token
118
+ )
119
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")
120
+
121
+ SENTENCE = "This is a test for sequence classification. " * 300
122
+ token_ids = tokenizer(
123
+ SENTENCE,
124
+ return_tensors="pt",
125
+ #pad_to_multiple_of=... # Optional
126
+ truncation=True
127
+ )
128
+ output = model(**token_ids)
129
+
130
+ > SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
131
+ ```
132
+
133
+ ## Training global tokens
134
+ To train global tokens and the classification head only:
135
+ ```python
136
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
137
+
138
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-xlm-roberta-base-4096",
139
+ trust_remote_code=True,
140
+ pool_with_global=True, # pool with a global token instead of first token
141
+ num_global_tokens=16
142
+ )
143
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-xlm-roberta-base-4096")
144
+
145
+ for name, param in model.named_parameters():
146
+ if "global_embeddings" not in name:
147
+ param.requires_grad = False
148
+ else:
149
+ param.required_grad = True
150
+ ```
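+
+ A quick sanity-check sketch (illustrative) to confirm which parameters remain trainable:
+
+ ```python
+ trainable = [name for name, param in model.named_parameters() if param.requires_grad]
+ print(f"{len(trainable)} trainable tensors, e.g. {trainable[:3]}")
+ ```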
attn.png ADDED
config.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "_name_or_path": "ccdv/lsg-xlm",
3
+ "adaptive": true,
4
+ "architectures": [
5
+ "LSGXLMRobertaForMaskedLM"
6
+ ],
7
+ "attention_probs_dropout_prob": 0.1,
8
+ "auto_map": {
9
+ "AutoConfig": "modeling_lsg_xlm_roberta.LSGXLMRobertaConfig",
10
+ "AutoModel": "modeling_lsg_xlm_roberta.LSGXLMRobertaModel",
11
+ "AutoModelForCausalLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForCausalLM",
12
+ "AutoModelForMaskedLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMaskedLM",
13
+ "AutoModelForMultipleChoice": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMultipleChoice",
14
+ "AutoModelForQuestionAnswering": "modeling_lsg_xlm_roberta.LSGXLMRobertaForQuestionAnswering",
15
+ "AutoModelForSequenceClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForSequenceClassification",
16
+ "AutoModelForTokenClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForTokenClassification"
17
+ },
18
+ "base_model_prefix": "lsg",
19
+ "block_size": 256,
20
+ "bos_token_id": 0,
21
+ "classifier_dropout": null,
22
+ "eos_token_id": 2,
23
+ "hidden_act": "gelu",
24
+ "hidden_dropout_prob": 0.1,
25
+ "hidden_size": 768,
26
+ "initializer_range": 0.02,
27
+ "intermediate_size": 3072,
28
+ "layer_norm_eps": 1e-05,
29
+ "lsh_num_pre_rounds": 1,
30
+ "mask_first_token": true,
31
+ "max_position_embeddings": 4098,
32
+ "model_type": "xlm-roberta",
33
+ "num_attention_heads": 12,
34
+ "num_global_tokens": 1,
35
+ "num_hidden_layers": 12,
36
+ "output_past": true,
37
+ "pad_token_id": 1,
38
+ "pool_with_global": true,
39
+ "position_embedding_type": "absolute",
40
+ "sparse_block_size": 0,
41
+ "sparsity_factor": 4,
42
+ "sparsity_type": "none",
43
+ "torch_dtype": "float32",
44
+ "transformers_version": "4.20.1",
45
+ "type_vocab_size": 1,
46
+ "use_cache": true,
47
+ "vocab_size": 250002
48
+ }
modeling_lsg_xlm_roberta.py ADDED
@@ -0,0 +1,1249 @@
1
+ from logging import warn
2
+ from transformers.models.roberta.modeling_roberta import *
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
6
+ import sys
7
+
8
+ AUTO_MAP = {
9
+ "AutoModel": "modeling_lsg_xlm_roberta.LSGXLMRobertaModel",
10
+ "AutoModelForCausalLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForCausalLM",
11
+ "AutoModelForMaskedLM": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMaskedLM",
12
+ "AutoModelForMultipleChoice": "modeling_lsg_xlm_roberta.LSGXLMRobertaForMultipleChoice",
13
+ "AutoModelForQuestionAnswering": "modeling_lsg_xlm_roberta.LSGXLMRobertaForQuestionAnswering",
14
+ "AutoModelForSequenceClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForSequenceClassification",
15
+ "AutoModelForTokenClassification": "modeling_lsg_xlm_roberta.LSGXLMRobertaForTokenClassification"
16
+ }
17
+
18
+ class LSGXLMRobertaConfig(XLMRobertaConfig):
19
+ """
20
+ This class overrides :class:`~transformers.RobertaConfig`. Please check the superclass for the appropriate
21
+ documentation alongside usage examples.
22
+ """
23
+
24
+ base_model_prefix = "lsg"
25
+ model_type = "xlm-roberta"
26
+
27
+ def __init__(
28
+ self,
29
+ adaptive=True,
30
+ base_model_prefix="lsg",
31
+ block_size=128,
32
+ lsh_num_pre_rounds=1,
33
+ mask_first_token=False,
34
+ num_global_tokens=1,
35
+ pool_with_global=True,
36
+ sparse_block_size=128,
37
+ sparsity_factor=2,
38
+ sparsity_type="norm",
39
+ **kwargs
40
+ ):
41
+ """Constructs LSGXLMRobertaConfig."""
42
+ super().__init__(**kwargs)
43
+
44
+ self.adaptive = adaptive
45
+ self.auto_map = AUTO_MAP
46
+ self.base_model_prefix = base_model_prefix
47
+ self.block_size = block_size
48
+ self.lsh_num_pre_rounds = lsh_num_pre_rounds
49
+ self.mask_first_token = mask_first_token
50
+ self.num_global_tokens = num_global_tokens
51
+ self.pool_with_global = pool_with_global
52
+ self.sparse_block_size = sparse_block_size
53
+ self.sparsity_factor = sparsity_factor
54
+ self.sparsity_type = sparsity_type
55
+
56
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
57
+ logger.warning(
58
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
59
+ self.sparsity_type = None
60
+
61
+ if self.sparsity_type in ["stride", "block_stride"]:
62
+ if self.sparsity_factor > self.num_attention_heads:
63
+ logger.warning(
64
+ "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
65
+ )
66
+
67
+ if self.num_global_tokens < 1:
68
+ logger.warning(
69
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
70
+ )
71
+ self.num_global_tokens = 1
72
+ elif self.num_global_tokens > 512:
73
+ logger.warning(
74
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
75
+ )
76
+ self.num_global_tokens = 512
77
+
78
+ if self.sparsity_factor > 0:
79
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
80
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
81
+
82
+
83
+ class BaseSelfAttention(nn.Module):
84
+
85
+ def init_modules(self, config):
86
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
87
+ config, "embedding_size"
88
+ ):
89
+ raise ValueError(
90
+ "The hidden size (%d) is not a multiple of the number of attention "
91
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
92
+ )
93
+
94
+ self.num_attention_heads = config.num_attention_heads
95
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
96
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
97
+
98
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
99
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
100
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
101
+
102
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
103
+
104
+ def transpose_for_scores(self, x):
105
+ new_x_shape = x.size()[:-1] + (
106
+ self.num_attention_heads,
107
+ self.attention_head_size,
108
+ )
109
+ x = x.view(*new_x_shape)
110
+ return x.permute(0, 2, 1, 3)
111
+
112
+ def reshape_output(self, context_layer):
113
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
114
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
115
+ return context_layer.view(*new_context_layer_shape)
116
+
117
+ def project_QKV(self, hidden_states):
118
+
119
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
120
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
121
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
122
+ return query_layer, key_layer, value_layer
123
+
124
+
125
+ class BaseAttentionProduct(nn.Module):
126
+
127
+ def __init__(self, config):
128
+ """
129
+ Compute attention: softmax(Q @ K.T) @ V
130
+ """
131
+ super().__init__()
132
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
133
+
134
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
135
+
136
+ d = query_layer.shape[-1]
137
+
138
+ # Take the dot product between "query" and "key" to get the raw attention scores.
139
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
140
+
141
+ del query_layer
142
+ del key_layer
143
+
144
+ if attention_mask is not None:
145
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
146
+ attention_scores = attention_scores + attention_mask
147
+ del attention_mask
148
+
149
+ # Normalize the attention scores to probabilities.
150
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
151
+
152
+ # This is actually dropping out entire tokens to attend to, which might
153
+ # seem a bit unusual, but is taken from the original Transformer paper.
154
+ context_layer = self.dropout(attention_probs) @ value_layer
155
+
156
+ return context_layer
157
+
158
+
159
+ class CausalAttentionProduct(nn.Module):
160
+
161
+ def __init__(self, config):
162
+ """
163
+ Compute attention: softmax(Q @ K.T) @ V
164
+ """
165
+ super().__init__()
166
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
167
+ self.block_size = config.block_size
168
+
169
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
170
+
171
+ d = query_layer.shape[-1]
172
+
173
+ # Take the dot product between "query" and "key" to get the raw attention scores.
174
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
175
+
176
+ del query_layer
177
+ del key_layer
178
+
179
+ if attention_mask is not None:
180
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
181
+ attention_scores = attention_scores + attention_mask
182
+
183
+ # Add causal mask
184
+ causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
185
+ causal_mask = torch.tril(
186
+ torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
187
+ diagonal=-1
188
+ )
189
+ causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
190
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
191
+
192
+ del attention_mask
193
+
194
+ # Normalize the attention scores to probabilities.
195
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
196
+
197
+ # This is actually dropping out entire tokens to attend to, which might
198
+ # seem a bit unusual, but is taken from the original Transformer paper.
199
+ context_layer = self.dropout(attention_probs) @ value_layer
200
+
201
+ return context_layer
202
+
203
+
204
+ class LSGAttentionProduct(nn.Module):
205
+
206
+ def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
207
+ """
208
+ Compute block or overlapping blocks attention products
209
+ """
210
+ super().__init__()
211
+
212
+ self.block_size = block_size
213
+ self.sparse_block_size = sparse_block_size
214
+ self.sparsity_factor = sparsity_factor
215
+ self.is_causal = is_causal
216
+
217
+ if self.block_size is None:
218
+ self.block_size = config.block_size
219
+
220
+ if self.sparse_block_size is None:
221
+ self.sparse_block_size = config.sparse_block_size
222
+
223
+ # Shape of blocks
224
+ self.local_shapes = (self.block_size*3, self.block_size)
225
+ if self.sparse_block_size and self.sparsity_factor > 0:
226
+ self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
227
+
228
+ if is_causal:
229
+ self.attention = CausalAttentionProduct(config)
230
+ else:
231
+ self.attention = BaseAttentionProduct(config)
232
+
233
+ def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
234
+
235
+ # Build local tokens
236
+ local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
237
+ del hidden_states
238
+
239
+ # Build sparse tokens
240
+ if sparse_hidden_states is not None:
241
+ sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
242
+
243
+ return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
244
+
245
+ def forward(
246
+ self,
247
+ query_layer,
248
+ key_layer,
249
+ value_layer,
250
+ attention_mask=None,
251
+ sparse_key=None,
252
+ sparse_value=None,
253
+ sparse_mask=None,
254
+ global_key=None,
255
+ global_value=None,
256
+ global_mask=None
257
+ ):
258
+
259
+ # Input batch, heads, length, hidden_size
260
+ n, h, t, d = query_layer.size()
261
+ n_blocks = t // self.block_size
262
+ assert t % self.block_size == 0
263
+
264
+ key_layer = self.build_lsg_inputs(
265
+ key_layer,
266
+ sparse_key,
267
+ global_key
268
+ )
269
+ del sparse_key
270
+ del global_key
271
+
272
+ value_layer = self.build_lsg_inputs(
273
+ value_layer,
274
+ sparse_value,
275
+ global_value
276
+ )
277
+ del sparse_value
278
+ del global_value
279
+
280
+ attention_mask = self.build_lsg_inputs(
281
+ attention_mask,
282
+ sparse_mask,
283
+ global_mask.transpose(-1, -2),
284
+ is_attn_mask=True
285
+ ).transpose(-1, -2)
286
+ del sparse_mask
287
+ del global_mask
288
+
289
+ # expect (..., t, d) shape
290
+ # Compute attention
291
+ context_layer = self.attention(
292
+ query_layer=self.chunk(query_layer, n_blocks),
293
+ key_layer=key_layer,
294
+ value_layer=value_layer,
295
+ attention_mask=attention_mask
296
+ )
297
+
298
+ return context_layer.reshape(n, h, -1, d)
299
+
300
+ def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
301
+
302
+ size, step = self.local_shapes
303
+ s = (size - step) // 2
304
+
305
+ # Pad before block reshaping
306
+ if is_attn_mask:
307
+ pad_value = torch.finfo(hidden_states.dtype).min
308
+ hidden_states = hidden_states.transpose(-1, -2)
309
+ else:
310
+ pad_value = 0
311
+
312
+ hidden_states = torch.nn.functional.pad(
313
+ hidden_states.transpose(-1, -2),
314
+ pad=(s, s),
315
+ value=pad_value
316
+ ).transpose(-1, -2)
317
+
318
+ # Make blocks
319
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
320
+
321
+ # Skip third block if causal
322
+ if self.is_causal:
323
+ return hidden_states[..., :size*2//3, :]
324
+
325
+ return hidden_states
326
+
327
+ def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
328
+
329
+ size, step = self.sparse_shapes
330
+
331
+ # In case of odd case
332
+ odd_offset = (step % 2)
333
+
334
+ # n, h, t, d*2 + 1
335
+ size = size*2
336
+ s = (size - step) // 2 + odd_offset
337
+
338
+ # Pad before block reshaping
339
+ if is_attn_mask:
340
+ pad_value = torch.finfo(hidden_states.dtype).min
341
+ hidden_states = hidden_states.transpose(-1, -2)
342
+ else:
343
+ pad_value = 0
344
+
345
+ hidden_states = torch.nn.functional.pad(
346
+ hidden_states.transpose(-1, -2),
347
+ pad=(s, s),
348
+ value=pad_value
349
+ ).transpose(-1, -2)
350
+
351
+ # Make blocks
352
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
353
+
354
+ # Fix case where block_size == sparsity_factor
355
+ if odd_offset:
356
+ hidden_states = hidden_states[..., :-1, :, :]
357
+
358
+ # Indexes for selection
359
+ u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
360
+ s = self.sparse_block_size
361
+
362
+ # Skip right block if causal
363
+ if self.is_causal:
364
+ return hidden_states[..., u-s:u, :]
365
+
366
+ u_ = u + odd_offset
367
+ return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
368
+
369
+ def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
370
+
371
+ n, h, b, t, d = x_local.size()
372
+ x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
373
+ if x_sparse is not None:
374
+ return torch.cat([x_global, x_sparse, x_local], dim=dim)
375
+ return torch.cat([x_global, x_local], dim=dim)
376
+
377
+ def chunk(self, x, n_blocks):
378
+
379
+ t, d = x.size()[-2:]
380
+ return x.reshape(*x.size()[:-2], n_blocks, -1, d)
381
+
382
+
383
+ class LSGRobertaEmbeddings(RobertaEmbeddings):
384
+
385
+ def __init__(self, config):
386
+ super().__init__(config)
387
+
388
+ self.num_global_tokens = config.num_global_tokens
389
+
390
+ # Hardcoded but partially trained
391
+ self.global_embeddings = nn.Embedding(512, embedding_dim=config.hidden_size, )
392
+
393
+ self.block_size = config.block_size
394
+
395
+ def forward(
396
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
397
+ ):
398
+ if position_ids is None:
399
+ if input_ids is not None:
400
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
401
+ position_ids = create_position_ids_from_input_ids(
402
+ input_ids, self.padding_idx, past_key_values_length
403
+ ).to(input_ids.device)
404
+ else:
405
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
406
+
407
+ if input_ids is not None:
408
+ input_shape = input_ids.size()
409
+ else:
410
+ input_shape = inputs_embeds.size()[:-1]
411
+
412
+ seq_length = input_shape[-1]
413
+
414
+ if token_type_ids is None:
415
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
416
+
417
+ if inputs_embeds is None:
418
+ inputs_embeds = self.word_embeddings(input_ids)
419
+ token_type_embeddings = self.token_type_embeddings(token_type_ids[:, :seq_length])
420
+
421
+ embeddings = inputs_embeds + token_type_embeddings
422
+ if self.position_embedding_type == "absolute":
423
+ position_embeddings = self.position_embeddings(position_ids[:, :seq_length])
424
+ embeddings += position_embeddings
425
+
426
+ #if self.num_global_tokens < 0:
427
+ n, t, d = embeddings.size()
428
+
429
+ # Add global_tokens
430
+ indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
431
+ global_embeddings = self.global_embeddings(indexes)
432
+ embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
433
+
434
+ embeddings = self.LayerNorm(embeddings)
435
+ embeddings = self.dropout(embeddings)
436
+ return embeddings
437
+
438
+
439
+ class LSGRobertaSelfOutput(RobertaSelfOutput):
440
+
441
+ def __init__(self, config):
442
+ super().__init__(config)
443
+
444
+
445
+ class LSGAttention(RobertaAttention):
446
+
447
+ def __init__(self, config):
448
+
449
+ nn.Module.__init__(self)
450
+
451
+ self.self = LSGSelfAttention(config)
452
+ self.output = LSGRobertaSelfOutput(config)
453
+ self.pruned_heads = set()
454
+
455
+
456
+ class LSGRobertaIntermediate(RobertaIntermediate):
457
+
458
+ def __init__(self, config):
459
+ super().__init__(config)
460
+
461
+
462
+ class LSGRobertaOutput(RobertaOutput):
463
+
464
+ def __init__(self, config):
465
+ super().__init__(config)
466
+
467
+
468
+ class LSGRobertaPooler(RobertaPooler):
469
+
470
+ def __init__(self, config):
471
+ super().__init__(config)
472
+
473
+
474
+ class LSGSelfAttention(BaseSelfAttention):
475
+ '''
476
+ Compute local attention with overlapping blocks
477
+ Use global attention for tokens with highest norm
478
+ '''
479
+ def __init__(self, config):
480
+ super().__init__()
481
+
482
+ self.init_modules(config)
483
+
484
+ self.block_size = config.block_size
485
+ self.sparse_block_size = config.sparse_block_size
486
+ self.num_global_tokens = config.num_global_tokens
487
+ self.sparsity_factor = config.sparsity_factor
488
+ self.is_causal = config.is_decoder
489
+ self.is_decoder = config.is_decoder
490
+
491
+ self.attention = LSGAttentionProduct(
492
+ config,
493
+ block_size=config.block_size,
494
+ sparse_block_size=config.sparse_block_size,
495
+ sparsity_factor=self.sparsity_factor,
496
+ is_causal=self.is_causal
497
+ )
498
+
499
+ if self.is_causal:
500
+ self.causal_attention = CausalAttentionProduct(config)
501
+ self.full_attention = BaseAttentionProduct(config)
502
+
503
+ sparse_functions = {
504
+ "norm": self.get_sparse_tokens_with_norm,
505
+ "pooling": self.get_sparse_tokens_with_pooling,
506
+ "lsh": self.get_sparse_tokens_with_lsh,
507
+ "stride": self.get_sparse_tokens_with_stride,
508
+ "block_stride": self.get_sparse_tokens_with_block_stride,
509
+ }
510
+
511
+ self.sparsity_type = config.sparsity_type
512
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
513
+
514
+ if config.sparsity_type == "lsh":
515
+ self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
516
+
517
+ def get_sparse_tokens_with_norm(self, keys, values, mask):
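+ # Within each block, keep only the tokens whose key norm is highest (padded positions are
+ # zeroed out beforehand), so the sparse context is reduced by a factor of sparsity_factor.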
518
+
519
+ if self.sparsity_factor == 1:
520
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
521
+
522
+ with torch.no_grad():
523
+
524
+ block_size = min(self.block_size, self.sparse_block_size)
525
+ key_norm = keys.detach().norm(dim=-1, keepdim=True)
526
+ key_norm = key_norm * ~mask.transpose(-1, -2).bool()
527
+ key_norm = self.chunk(key_norm, block_size)
528
+
529
+ n, h, b, t, d = key_norm.size()
530
+
531
+ idx = key_norm.argsort(dim=-2)
532
+ del key_norm
533
+ idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
534
+
535
+ split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
536
+ sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
537
+
538
+ d = keys.size()[-1]
539
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
540
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
541
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
542
+
543
+ return keys, values, mask
544
+
545
+ def get_sparse_tokens_with_pooling(self, keys, values, mask):
546
+
547
+ if self.sparsity_factor == 1:
548
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
549
+
550
+ keys = self.chunk(keys, self.sparsity_factor)
551
+ values = self.chunk(values, self.sparsity_factor)
552
+
553
+ n, h, b, t, d = keys.size()
554
+ mask = mask.reshape(n, 1, b, 1, t)
555
+ mask = ~mask.transpose(-1, -2).bool()
556
+
557
+ keys = keys * mask
558
+ values = values * mask
559
+
560
+ mask = mask.sum(dim=-2)
561
+ keys = keys.sum(dim=-2) / (mask + 1e-6)
562
+ values = values.sum(dim=-2) / (mask + 1e-6)
563
+
564
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
565
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
566
+
567
+ def get_sparse_tokens_with_stride(self, keys, values, mask):
568
+
569
+ if self.sparsity_factor == 1:
570
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
571
+
572
+ n, h, t, d = keys.size()
573
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
574
+ sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
575
+ sparse_idx = sparse_idx.expand(n, h, -1, 1)
576
+
577
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
578
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
579
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
580
+
581
+ return keys, values, mask
582
+
583
+ def get_sparse_tokens_with_block_stride(self, keys, values, mask):
584
+
585
+ if self.sparsity_factor == 1:
586
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
587
+
588
+ n, h, t, d = keys.size()
589
+
590
+ t, b = self.block_size, t // self.block_size
591
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
592
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
593
+ sparse_idx = (sparse_idx % t)
594
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
595
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
596
+
597
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
598
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
599
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
600
+
601
+ return keys, values, mask
602
+
603
+ def get_sparse_tokens_with_lsh(self, keys, values, mask):
604
+
605
+ if self.sparsity_factor == 1:
606
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
607
+
608
+ block_size = min(self.block_size, self.sparse_block_size)
609
+ keys = self.chunk(keys, block_size)
610
+ values = self.chunk(values, block_size)
611
+
612
+ n, h, b, t, d = keys.size()
613
+ mask = mask.reshape(n, 1, b, 1, t)
614
+ mask = ~mask.transpose(-1, -2).bool()
615
+
616
+ keys = keys * mask
617
+ values = values * mask
618
+ mask = mask.expand(-1, h, -1, -1, -1).float()
619
+
620
+ extra_factor = 1
621
+
622
+ for _ in range(self.lsh_num_pre_rounds):
623
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
624
+
625
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
626
+ keys /= mask + 1e-8
627
+ values /= mask + 1e-8
628
+
629
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
630
+
631
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
632
+
633
+ def lsh_round(self, keys, values, mask, output_size):
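+ # One LSH round: random projections assign each token a bucket id in [0, output_size),
+ # then keys/values/mask of tokens falling in the same bucket are summed via scatter_add.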
634
+
635
+ with torch.no_grad():
636
+
637
+ n_hashes = output_size // 2
638
+ n, h, b, t, d = keys.size()
639
+ binary_mask = mask.clamp(0, 1)
640
+
641
+ indexes = (torch.nn.functional.normalize(keys, dim=-1) * binary_mask) @ torch.randn(1, h, 1, d, n_hashes, device=keys.device)
642
+ indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)
643
+
644
+ n, h, b, t, d = keys.size()
645
+
646
+ x_ = torch.zeros(n, h, b, output_size, d, device=keys.device)
647
+ mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device)
648
+ keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
649
+ values = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
650
+ mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)
651
+
652
+ return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
653
+
654
+ def forward(
655
+ self,
656
+ hidden_states,
657
+ attention_mask=None,
658
+ head_mask=None,
659
+ encoder_hidden_states=None,
660
+ encoder_attention_mask=None,
661
+ past_key_value=None,
662
+ output_attentions=False,
663
+ ):
664
+
665
+ query_layer = self.query(hidden_states)
666
+
667
+ # If this is instantiated as a cross-attention module, the keys
668
+ # and values come from an encoder; the attention mask needs to be
669
+ # such that the encoder's padding tokens are not attended to.
670
+ is_cross_attention = encoder_hidden_states is not None
671
+
672
+ if is_cross_attention and past_key_value is not None:
673
+ # reuse k,v, cross_attentions
674
+ key_layer = past_key_value[0]
675
+ value_layer = past_key_value[1]
676
+ attention_mask = encoder_attention_mask
677
+ elif is_cross_attention:
678
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
679
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
680
+ attention_mask = encoder_attention_mask
681
+ elif past_key_value is not None:
682
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
683
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
684
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
685
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
686
+ else:
687
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
688
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
689
+
690
+ query_layer = self.transpose_for_scores(query_layer)
691
+
692
+ if self.is_decoder:
693
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
694
+ # Further calls to cross_attention layer can then reuse all cross-attention
695
+ # key/value_states (first "if" case)
696
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
697
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
698
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
699
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
700
+ past_key_value = (key_layer, value_layer)
701
+
702
+ if is_cross_attention:
703
+ outputs = self.cross_attention_forward(
704
+ query_layer=query_layer,
705
+ key_layer=key_layer,
706
+ value_layer=value_layer,
707
+ attention_mask=attention_mask,
708
+ output_attentions=output_attentions
709
+ )
710
+ else:
711
+ outputs = self.causal_forward(
712
+ query_layer,
713
+ key_layer,
714
+ value_layer,
715
+ attention_mask=attention_mask,
716
+ output_attentions=output_attentions,
717
+ )
718
+
719
+ outputs = outputs + ((key_layer, value_layer),)
720
+
721
+ else:
722
+ outputs = self.not_causal_forward(
723
+ query_layer,
724
+ key_layer,
725
+ value_layer,
726
+ attention_mask=attention_mask,
727
+ output_attentions=output_attentions
728
+ )
729
+
730
+ #if head_mask is not None:
731
+ # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
732
+ return outputs
733
+
734
+ def causal_forward(
735
+ self,
736
+ query_layer,
737
+ key_layer,
738
+ value_layer,
739
+ attention_mask=None,
740
+ output_attentions=False,
741
+ ):
742
+
743
+ n, h, t, d = key_layer.size()
744
+
745
+ # Cat global mask
746
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
747
+
748
+ # Split input into global tokens and other tokens
749
+ split = (self.num_global_tokens, t - self.num_global_tokens)
750
+ global_query, query_layer = query_layer.split(split, dim=-2)
751
+
752
+ # Use normal causal attention if local attention covers every tokens
753
+ if t <= 2 * self.block_size + self.num_global_tokens:
754
+ context_layer = self.causal_attention(
755
+ query_layer=query_layer,
756
+ key_layer=key_layer,
757
+ value_layer=value_layer,
758
+ attention_mask=attention_mask,
759
+ causal_shape=(t - self.num_global_tokens, t - self.num_global_tokens)
760
+ )
761
+
762
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
763
+ return (self.reshape_output(context_layer), )
764
+
765
+ # Split K Q M on global and non global
766
+ global_key, key_layer = key_layer.split(split, dim=-2)
767
+ global_value, value_layer = value_layer.split(split, dim=-2)
768
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
769
+
770
+ n, h, t, d = key_layer.size()
771
+
772
+ # Get sparse idx
773
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
774
+ if self.sparse_block_size and self.sparsity_factor > 0:
775
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
776
+
777
+ # Expand masks on heads
778
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
779
+ global_mask = global_mask.expand(-1, h, -1, -1)
780
+
781
+ # Compute dot product attention
782
+ context_layer = self.attention(
783
+ query_layer,
784
+ key_layer,
785
+ value_layer,
786
+ attention_mask,
787
+ sparse_key=sparse_key,
788
+ sparse_value=sparse_value,
789
+ sparse_mask=sparse_mask,
790
+ global_key=global_key,
791
+ global_value=global_value,
792
+ global_mask=global_mask
793
+ )
794
+
795
+ # Merge pseudo global (causal) and local-sparse tokens
796
+ context_layer = torch.cat([global_query, context_layer], dim=-2)
797
+ context_layer = self.reshape_output(context_layer)
798
+
799
+ return (context_layer,)
800
+
801
+ def not_causal_forward(
802
+ self,
803
+ query_layer,
804
+ key_layer,
805
+ value_layer,
806
+ attention_mask=None,
807
+ output_attentions=False,
808
+ ):
809
+
810
+ n, h, t, d = query_layer.size()
811
+
812
+ # Cat global mask
813
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
814
+
815
+ # Use normal attention if local attention covers every tokens
816
+ if t <= 2 * self.block_size + self.num_global_tokens:
817
+ context_layer = self.full_attention(
818
+ query_layer=query_layer,
819
+ key_layer=key_layer,
820
+ value_layer=value_layer,
821
+ attention_mask=attention_mask
822
+ )
823
+ return (self.reshape_output(context_layer), )
824
+
825
+ # Split input into global tokens and other tokens
826
+ split = (self.num_global_tokens, t - self.num_global_tokens)
827
+ global_query, query_layer = query_layer.split(split, dim=-2)
828
+
829
+ # Get global_attention
830
+ bos = self.full_attention(
831
+ query_layer=global_query,
832
+ key_layer=key_layer,
833
+ value_layer=value_layer,
834
+ attention_mask=attention_mask
835
+ )
836
+
837
+ # Split K Q M on global and non global
838
+ global_key, key_layer = key_layer.split(split, dim=-2)
839
+ global_value, value_layer = value_layer.split(split, dim=-2)
840
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
841
+
842
+ n, h, t, d = key_layer.size()
843
+
844
+ # Get sparse idx
845
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
846
+
847
+ if self.sparse_block_size and self.sparsity_factor > 0:
848
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
849
+
850
+ # Expand masks on heads
851
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
852
+ global_mask = global_mask.expand(-1, h, -1, -1)
853
+
854
+ # Compute dot product attention
855
+ context_layer = self.attention(
856
+ query_layer,
857
+ key_layer,
858
+ value_layer,
859
+ attention_mask,
860
+ sparse_key=sparse_key,
861
+ sparse_value=sparse_value,
862
+ sparse_mask=sparse_mask,
863
+ global_key=global_key,
864
+ global_value=global_value,
865
+ global_mask=global_mask
866
+ )
867
+
868
+ # Merge global and local-sparse tokens
869
+ context_layer = torch.cat([bos, context_layer], dim=-2)
870
+ context_layer = self.reshape_output(context_layer)
871
+
872
+ return (context_layer,)
873
+
874
+ def cross_attention_forward(
875
+ self,
876
+ query_layer,
877
+ key_layer,
878
+ value_layer,
879
+ attention_mask=None,
880
+ output_attentions=False,
881
+ ):
882
+
883
+ context_layer = self.full_attention(
884
+ query_layer=query_layer,
885
+ key_layer=key_layer,
886
+ value_layer=value_layer,
887
+ attention_mask=attention_mask
888
+ )
889
+ return (self.reshape_output(context_layer), )
890
+
891
+ def chunk(self, x, chunk_size):
892
+
893
+ n, h, t, d = x.size()
894
+ return x.reshape(n, h, -1, chunk_size, d)
895
+
896
+
897
+ class LSGRobertaLayer(RobertaLayer):
898
+
899
+ def __init__(self, config):
900
+
901
+ nn.Module.__init__(self)
902
+
903
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
904
+ self.seq_len_dim = 1
905
+ self.attention = LSGAttention(config)
906
+ self.is_decoder = config.is_decoder
907
+ self.add_cross_attention = config.add_cross_attention
908
+ if self.add_cross_attention:
909
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
910
+ self.crossattention = LSGAttention(config)
911
+ self.intermediate = LSGRobertaIntermediate(config)
912
+ self.output = LSGRobertaOutput(config)
913
+
914
+
915
+ class LSGRobertaEncoder(RobertaEncoder):
916
+
917
+ def __init__(self, config):
918
+
919
+ nn.Module.__init__(self)
920
+
921
+ self.config = config
922
+ self.layer = nn.ModuleList([LSGRobertaLayer(config) for _ in range(config.num_hidden_layers)])
923
+ self.gradient_checkpointing = False
924
+
925
+
926
+ class LSGRobertaPreTrainedModel(RobertaPreTrainedModel):
927
+ """
928
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
929
+ models.
930
+ """
931
+
932
+ config_class = LSGXLMRobertaConfig
933
+
934
+ def _set_gradient_checkpointing(self, module, value=False):
935
+ if isinstance(module, (RobertaEncoder, LSGRobertaEncoder)):
936
+ module.gradient_checkpointing = value
937
+
938
+
939
+ class LSGXLMRobertaModel(LSGRobertaPreTrainedModel, RobertaModel):
940
+ """
941
+ This class overrides :class:`~transformers.RobertaModel`. Please check the superclass for the appropriate
942
+ documentation alongside usage examples.
943
+ """
944
+
945
+ config_class = LSGXLMRobertaConfig
946
+
947
+
948
+ def __init__(self, config, add_pooling_layer=False):
949
+
950
+ LSGRobertaPreTrainedModel.__init__(self, config)
951
+
952
+ assert hasattr(config, "num_global_tokens")
953
+ self.num_global_tokens = config.num_global_tokens
954
+ self.pad_idx = config.pad_token_id
955
+
956
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
957
+ self.block_size = config.block_size
958
+ self.adaptive = config.adaptive
959
+ self.mask_first_token = config.mask_first_token
960
+ self.pool_with_global = config.pool_with_global
961
+
962
+ self.embeddings = LSGRobertaEmbeddings(config)
963
+ self.encoder = LSGRobertaEncoder(config)
964
+ self.pooler = LSGRobertaPooler(config) if add_pooling_layer else None
965
+
966
+ if config.add_cross_attention:
967
+ logger.warning(
968
+ "Cross attention is computed using full attention since it is not LSG compatible."
969
+ )
970
+
971
+ # Initialize weights and apply final processing
972
+ self.post_init()
973
+
974
+ def forward(
975
+ self,
976
+ input_ids=None,
977
+ attention_mask=None,
978
+ token_type_ids=None,
979
+ position_ids=None,
980
+ head_mask=None,
981
+ inputs_embeds=None,
982
+ encoder_hidden_states=None,
983
+ encoder_attention_mask=None,
984
+ past_key_values=None,
985
+ use_cache=None,
986
+ output_attentions=None,
987
+ output_hidden_states=None,
988
+ return_dict=None
989
+ ):
990
+
991
+ inputs_ = input_ids if input_ids is not None else inputs_embeds
992
+ n, t = inputs_.size()[:2]
993
+
994
+ if attention_mask is None:
995
+ attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
996
+ if self.mask_first_token:
997
+ attention_mask[:,0] = 0
998
+
999
+ b = self.block_size * 2
1000
+ pad = t % self.block_size
1001
+
1002
+ # Check if t is multiple of block_size and pad
1003
+ if self.adaptive and t > b and pad > 0:
1004
+ pad_length = self.block_size - pad
1005
+ if input_ids is not None:
1006
+ input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
1007
+ else:
1008
+ inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
1009
+
1010
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
1011
+
1012
+ if token_type_ids is not None:
1013
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
1014
+ if position_ids is not None:
1015
+ position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
1016
+
1017
+ n, t_ = attention_mask.size()
1018
+
1019
+ encoder_outputs = super().forward(
1020
+ input_ids=input_ids,
1021
+ attention_mask=attention_mask,
1022
+ token_type_ids=token_type_ids,
1023
+ position_ids=position_ids,
1024
+ head_mask=head_mask,
1025
+ inputs_embeds=inputs_embeds,
1026
+ encoder_hidden_states=encoder_hidden_states,
1027
+ encoder_attention_mask=encoder_attention_mask,
1028
+ past_key_values=past_key_values,
1029
+ use_cache=use_cache,
1030
+ output_attentions=output_attentions,
1031
+ output_hidden_states=output_hidden_states,
1032
+ return_dict=return_dict
1033
+ )
1034
+
1035
+ context = encoder_outputs[0]
1036
+ if self.pool_with_global:
1037
+ context[:, self.num_global_tokens] = context[:, 0]
1038
+
1039
+ diff = t - t_
1040
+ n, _, d = context.size()
1041
+ context = context[..., self.num_global_tokens:, :]
1042
+
1043
+ # Adapt sequence to initial shape
1044
+ if diff < 0:
1045
+ context = context[:, :t]
1046
+
1047
+ encoder_outputs.last_hidden_state = context
1048
+ sequence_output = encoder_outputs[0]
1049
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1050
+
1051
+ if not return_dict:
1052
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1053
+
1054
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1055
+ last_hidden_state=sequence_output,
1056
+ pooler_output=pooled_output,
1057
+ past_key_values=encoder_outputs.past_key_values,
1058
+ hidden_states=encoder_outputs.hidden_states,
1059
+ attentions=encoder_outputs.attentions,
1060
+ cross_attentions=encoder_outputs.cross_attentions,
1061
+ )
1062
+
1063
+ def get_extended_attention_mask(self, attention_mask, input_shape, device=None):
1064
+
1065
+ # Do not rely on original triangular mask from BERT/RoBERTa for causalLM
1066
+ if attention_mask.dim() == 3:
1067
+ extended_attention_mask = attention_mask[:, None, :, :]
1068
+ elif attention_mask.dim() == 2:
1069
+ extended_attention_mask = attention_mask[:, None, None, :]
1070
+ else:
1071
+ raise ValueError(
1072
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
1073
+ )
1074
+
1075
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1076
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
1077
+
1078
+ return extended_attention_mask
1079
+
1080
+
1081
+ class LSGXLMRobertaForCausalLM(LSGRobertaPreTrainedModel, RobertaForCausalLM):
1082
+
1083
+ config_class = LSGXLMRobertaConfig
1084
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1085
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1086
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1087
+
1088
+ def __init__(self, config):
1089
+
1090
+ LSGRobertaPreTrainedModel.__init__(self, config)
1091
+
1092
+ if not config.is_decoder:
1093
+ logger.warning("If you want to use `LSGRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
1094
+
1095
+ self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1096
+ self.lm_head = LSGRobertaLMHead(config)
1097
+
1098
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1099
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1100
+
1101
+ # Initialize weights and apply final processing
1102
+ self.post_init()
1103
+
1104
+
1105
+ class LSGXLMRobertaForMaskedLM(LSGRobertaPreTrainedModel, RobertaForMaskedLM):
1106
+ """
1107
+ This class overrides :class:`~transformers.RobertaForMaskedLM`. Please check the superclass for the appropriate
1108
+ documentation alongside usage examples.
1109
+ """
1110
+ config_class = LSGXLMRobertaConfig
1111
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1112
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1113
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1114
+
1115
+ def __init__(self, config):
1116
+
1117
+ LSGRobertaPreTrainedModel.__init__(self, config)
1118
+
1119
+ if config.is_decoder:
1120
+ logger.warning(
1121
+ "If you want to use `LSGRobertaForMaskedLM` make sure `config.is_decoder=False` for "
1122
+ "bi-directional self-attention."
1123
+ )
1124
+
1125
+ self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1126
+ self.lm_head = LSGRobertaLMHead(config)
1127
+
1128
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1129
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1130
+
1131
+ # Initialize weights and apply final processing
1132
+ self.post_init()
1133
+
1134
+
1135
+ class LSGRobertaLMHead(RobertaLMHead):
1136
+ """LSG Head for masked language modeling."""
1137
+
1138
+ def __init__(self, config):
1139
+ super().__init__(config)
1140
+
1141
+
1142
+ class LSGXLMRobertaForSequenceClassification(LSGRobertaPreTrainedModel, RobertaForSequenceClassification):
1143
+ """
1144
+ This class overrides :class:`~transformers.RobertaForSequenceClassification`. Please check the superclass for the
1145
+ appropriate documentation alongside usage examples.
1146
+ """
1147
+ config_class = LSGXLMRobertaConfig
1148
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1149
+
1150
+ def __init__(self, config):
1151
+
1152
+ LSGRobertaPreTrainedModel.__init__(self, config)
1153
+
1154
+ self.num_labels = config.num_labels
1155
+ self.config = config
1156
+
1157
+ self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1158
+ self.classifier = LSGRobertaClassificationHead(config)
1159
+
1160
+ # Initialize weights and apply final processing
1161
+ self.post_init()
1162
+
1163
+
1164
+ class LSGRobertaClassificationHead(RobertaClassificationHead):
1165
+ """Head for sentence-level classification tasks."""
1166
+
1167
+ def __init__(self, config):
1168
+ super().__init__(config)
1169
+
1170
+
1171
+ class LSGXLMRobertaForMultipleChoice(LSGRobertaPreTrainedModel, RobertaForMultipleChoice):
1172
+ """
1173
+ This class overrides :class:`~transformers.RobertaForMultipleChoice`. Please check the superclass for the
1174
+ appropriate documentation alongside usage examples.
1175
+ """
1176
+ config_class = LSGXLMRobertaConfig
1177
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1178
+
1179
+ def __init__(self, config):
1180
+
1181
+ LSGRobertaPreTrainedModel.__init__(self, config)
1182
+
1183
+ self.roberta = LSGXLMRobertaModel(config)
1184
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1185
+ self.classifier = nn.Linear(config.hidden_size, 1)
1186
+
1187
+ # Initialize weights and apply final processing
1188
+ self.post_init()
1189
+
1190
+
1191
+ class LSGXLMRobertaForTokenClassification(LSGRobertaPreTrainedModel, RobertaForTokenClassification):
1192
+ """
1193
+ This class overrides :class:`~transformers.RobertaForTokenClassification`. Please check the superclass for the
1194
+ appropriate documentation alongside usage examples.
1195
+ """
1196
+ config_class = LSGXLMRobertaConfig
1197
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1198
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1199
+
1200
+ def __init__(self, config):
1201
+
1202
+ LSGRobertaPreTrainedModel.__init__(self, config)
1203
+
1204
+ self.num_labels = config.num_labels
1205
+
1206
+ self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1207
+ classifier_dropout = (
1208
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1209
+ )
1210
+ self.dropout = nn.Dropout(classifier_dropout)
1211
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1212
+
1213
+ # Initialize weights and apply final processing
1214
+ self.post_init()
1215
+
1216
+
1217
+ class LSGXLMRobertaForQuestionAnswering(LSGRobertaPreTrainedModel, RobertaForQuestionAnswering):
1218
+ """
1219
+ This class overrides :class:`~transformers.RobertaForQuestionAnswering`. Please check the superclass for the
1220
+ appropriate documentation alongside usage examples.
1221
+ """
1222
+ config_class = LSGXLMRobertaConfig
1223
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1224
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1225
+
1226
+ def __init__(self, config):
1227
+
1228
+ LSGRobertaPreTrainedModel.__init__(self, config)
1229
+
1230
+ self.num_labels = config.num_labels
1231
+
1232
+ self.roberta = LSGXLMRobertaModel(config, add_pooling_layer=False)
1233
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1234
+
1235
+ # Initialize weights and apply final processing
1236
+ self.post_init()
1237
+
1238
+
1239
+ def str_to_class(classname):
1240
+ return getattr(sys.modules[__name__], classname)
1241
+
1242
+ # Register model in Auto API
1243
+ try:
1244
+ LSGXLMRobertaConfig.register_for_auto_class()
1245
+ for key, value in AUTO_MAP.items():
1246
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1247
+ except:
1248
+ warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1249
+ warn("Update to transformers >= 4.17.0 to fix.")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7f35de8bd314a5ffaa17d445d4158766e5552c96a4d1fbb222379ff45803ddd
3
+ size 1125861865
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
3
+ size 5069051
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62c24cdc13d4c9952d63718d6c9fa4c287974249e16b7ade6d5a85e7bbb75626
3
+ size 17082660
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "__type": "AddedToken",
7
+ "content": "<mask>",
8
+ "lstrip": true,
9
+ "normalized": true,
10
+ "rstrip": false,
11
+ "single_word": false
12
+ },
13
+ "model_max_length": 4096,
14
+ "name_or_path": "xlm-roberta-base",
15
+ "pad_token": "<pad>",
16
+ "sep_token": "</s>",
17
+ "special_tokens_map_file": null,
18
+ "tokenizer_class": "XLMRobertaTokenizer",
19
+ "unk_token": "<unk>"
20
+ }