BlackSamorez committed on
Commit
22d23ba
·
1 Parent(s): 865b874

initial implementation

Browse files
Files changed (5) hide show
  1. config.json +27 -0
  2. configuration_yalm.py +119 -0
  3. modeling_yalm.py +1083 -0
  4. spiece.model +3 -0
  5. tokenization_yalm.py +250 -0
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "YalmForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_yalm.YalmConfig",
8
+ "AutoModel": "modeling_yalm.YalmModel",
9
+ "AutoModelForSequenceClassification": "modeling_yalm.RWForSequenceClassification",
10
+ "AutoModelForCausalLM": "modeling_yalm.YalmForCausalLM"
11
+ },
12
+ "padded_vocab_size": 128000,
13
+ "embedding_size": 2048,
14
+ "hidden_size": 10240,
15
+ "intermediate_size": 27308,
16
+ "num_layers": 80,
17
+ "num_attention_heads": 128,
18
+ "scale_attn_by_inverse_layer_idx": true,
19
+ "activation_type": "geglu",
20
+ "model_type": "yalm",
21
+ "max_position_embeddings": 1024,
22
+ "apply_residual_connection_post_layernorm": false,
23
+ "initializer_range": 0.02,
24
+ "layernorm_epsilon": 1e-5,
25
+ "torch_dtype": "float16",
26
+ "transformers_version": "4.32.1"
27
+ }
configuration_yalm.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on Yandex's YaLM-100B library and the LLaMA
5
+ # implementations in transformers library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to LLaMA used by the Yandex team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """YaLM model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ YALM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
28
+
29
+
30
class YalmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`YalmModel`]. It is used to instantiate a YaLM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the YaLM-100B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        padded_vocab_size (`int`, *optional*, defaults to 128000):
            Vocabulary size of the YaLM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`YalmModel`]
        embedding_size (`int`, *optional*, defaults to 2048):
            Token embedding dimension.
        hidden_size (`int`, *optional*, defaults to 10240):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 27308):
            Dimension of the MLP representations.
        num_layers (`int`, *optional*, defaults to 80):
            Number of hidden layers in the Transformer.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer.
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `True`):
            Whether to scale attention output by inverse layer depth.
        activation_type (`str` or `function`, *optional*, defaults to `"geglu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
            If enabled, use the layer norm of the hidden states as the residual in the transformer blocks.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layernorm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability stored for the attention layers.
        hidden_dropout (`float`, *optional*, defaults to 0.1):
            Dropout probability stored for the transformer block outputs.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].

    Example:
    ```python
    >>> from configuration_yalm import YalmConfig
    >>> from modeling_yalm import YalmModel
    >>> # Initializing a YaLM yalm-100b style configuration
    >>> configuration = YalmConfig()
    >>> # Initializing a model from the yalm-100b style configuration
    >>> model = YalmModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "yalm"

    def __init__(
        self,
        padded_vocab_size=128000,
        embedding_size=2048,
        hidden_size=10240,
        intermediate_size=27308,
        num_layers=80,
        num_attention_heads=128,
        scale_attn_by_inverse_layer_idx=True,
        activation_type="geglu",
        max_position_embeddings=1024,
        apply_residual_connection_post_layernorm=False,
        initializer_range=0.02,
        layernorm_epsilon=1e-5,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.padded_vocab_size = padded_vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.activation_type = activation_type
        self.max_position_embeddings = max_position_embeddings
        # Honour the caller's value; it was previously hard-coded to False,
        # which silently discarded the constructor argument.
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.initializer_range = initializer_range
        self.layernorm_epsilon = layernorm_epsilon
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.use_cache = use_cache
        # Token ids and any remaining keyword arguments are handled by the base class.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
modeling_yalm.py ADDED
@@ -0,0 +1,1083 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on Yandex's YaLM-100B library and the LLaMA
5
+ # implementations in transformers library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to LLaMA used by the Yandex team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch YaLM model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss
28
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
29
+ CausalLMOutputWithPast)
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import (add_start_docstrings,
32
+ add_start_docstrings_to_model_forward, logging,
33
+ replace_return_docstrings)
34
+
35
+ from configuration_yalm import YalmConfig
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ _CONFIG_FOR_DOC = "YalmConfig"
40
+
41
+
42
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
43
+ def _make_causal_mask(
44
+ input_ids_shape: torch.Size,
45
+ dtype: torch.dtype,
46
+ device: torch.device,
47
+ past_key_values_length: int = 0,
48
+ ):
49
+ """
50
+ Make causal mask used for bi-directional self-attention.
51
+ """
52
+ bsz, tgt_len = input_ids_shape
53
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
54
+ mask_cond = torch.arange(mask.size(-1), device=device)
55
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
56
+ mask = mask.to(dtype)
57
+
58
+ if past_key_values_length > 0:
59
+ mask = torch.cat(
60
+ [
61
+ torch.zeros(
62
+ tgt_len, past_key_values_length, dtype=dtype, device=device
63
+ ),
64
+ mask,
65
+ ],
66
+ dim=-1,
67
+ )
68
+ return mask[None, None, :, :].expand(
69
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
70
+ )
71
+
72
+
73
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
74
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
75
+ """
76
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
77
+ """
78
+ bsz, src_len = mask.size()
79
+ tgt_len = tgt_len if tgt_len is not None else src_len
80
+
81
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
82
+
83
+ inverted_mask = 1.0 - expanded_mask
84
+
85
+ return inverted_mask.masked_fill(
86
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
87
+ )
88
+
89
+
90
class YalmRotaryPositionEncoding(nn.Module):
    """Rotary position encoding backed by precomputed cos/sin tables.

    The tables are computed once for `max_seq_length` positions and stored as
    non-persistent buffers of shape `[max_seq_length, 1, 1, hn // 2]`, so they
    broadcast against inputs shaped `[sq, b, np, hn]`.
    """

    def __init__(self, max_seq_length: int, hidden_size_per_attention_head: int, dtype):
        super().__init__()
        cos_cached, sin_cached = YalmRotaryPositionEncoding.get_cache_multipliers(
            max_seq_length, hidden_size_per_attention_head, dtype
        )
        # Insert singleton batch and head dimensions for broadcasting.
        self.register_buffer(
            "cos_cached", cos_cached.unsqueeze(1).unsqueeze(2), persistent=False
        )
        self.register_buffer(
            "sin_cached", sin_cached.unsqueeze(1).unsqueeze(2), persistent=False
        )

    def forward(self, hidden_state, context_position):
        """Rotate `hidden_state` using positions starting at `context_position`.

        Args:
            hidden_state: tensor whose first dimension is the sequence length and
                whose last dimension is the per-head size.
            context_position: position index of the first row of `hidden_state`.

        Raises:
            ValueError: if the requested positions fall outside the cached table.
                Without this check the slice below is silently truncated, which
                either fails with an opaque broadcasting error or, when a single
                row remains, broadcasts one position over the whole input.
        """
        seq_length = hidden_state.shape[0]
        end_position = context_position + seq_length
        max_seq_length = self.cos_cached.shape[0]
        if context_position < 0 or end_position > max_seq_length:
            raise ValueError(
                f"Rotary position range [{context_position}, {end_position}) is outside "
                f"the cached range [0, {max_seq_length})."
            )
        cache_slice = slice(context_position, end_position)
        return self.apply_rotary_position_encoding(
            hidden_state, self.cos_cached[cache_slice], self.sin_cached[cache_slice]
        )

    @staticmethod
    def get_cache_multipliers(max_seq_length, hidden_size, dtype):
        """Return `(cos, sin)` tables of shape `[max_seq_length, hidden_size // 2]`.

        The angle for position `p` and frequency index `i` is
        `p * 1e-4 ** (2 * i / hidden_size)`. Angles are computed in float32 and
        only the final tables are cast to `dtype`.
        """
        inv_freqs = 1e-4 ** (
            torch.arange(0, hidden_size, 2, dtype=torch.float) / hidden_size
        )
        positions = torch.arange(max_seq_length, dtype=torch.float)
        angles = positions.unsqueeze(-1) * inv_freqs

        return torch.cos(angles).to(dtype), torch.sin(angles).to(dtype)

    @staticmethod
    def apply_rotary_position_encoding(hidden_state, cos_cached, sin_cached):
        """Rotate the two halves of the last dimension against each other.

        Element `i` of the first half is paired with element `i` of the second
        half; the halves are concatenated again along dimension 3.
        """
        half_hn = hidden_state.shape[-1] // 2
        left, right = hidden_state[..., :half_hn], hidden_state[..., half_hn:]
        encoded_left = cos_cached * left - sin_cached * right
        encoded_right = sin_cached * left + cos_cached * right
        return torch.cat((encoded_left, encoded_right), dim=3)
128
+
129
+
130
class YalmSelfAttention(nn.Module):
    """Multi-head self-attention with rotary position encoding.

    `forward` takes hidden states laid out as [sq, b, h] (sequence first) and
    returns an attention output of the same shape, plus the key/value cache
    entry and, on request, the attention probabilities.
    """

    def __init__(self, config: YalmConfig, layer_idx: int):
        super().__init__()

        self.attention_mask_func = None
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx

        # Per attention head and per partition values.
        self.hidden_size_per_partition = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = (
            config.hidden_size // config.num_attention_heads
        )

        if (
            self.hidden_size_per_attention_head * self.num_attention_heads
        ) != self.hidden_size_per_partition:
            # Report the config values: this module defines no `hidden_size`
            # or `num_heads` attributes, so reading them here would raise
            # AttributeError instead of the intended ValueError.
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {config.hidden_size}"
                f" and `num_heads`: {config.num_attention_heads})."
            )

        self.num_attention_heads_per_partition = config.num_attention_heads

        # Fused projection producing query, key and value in a single matmul.
        self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)

        # When enabled, the raw scores are divided by an extra (layer_idx + 1)
        # inside the matmul and multiplied back by the same factor in `forward`.
        self.coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.scale_attn_by_inverse_layer_idx:
            self.coeff = self.layer_idx + 1
            self.norm_factor *= self.coeff

        # Constructed for completeness; `forward` currently does not apply it.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

        # Output projection.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)

        self.rotary_position_encoding = YalmRotaryPositionEncoding(
            config.max_position_embeddings,
            self.hidden_size_per_attention_head,
            dtype=self.dense.weight.dtype,
        )

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        """Reorder the fused last dimension so that heads come before splits.

        The returned tensor has the same shape as the input; only the order of
        elements inside the last dimension changes.
        """
        input_shape = mixed_layer.size()
        if num_splits_first:
            # [s, b, num_splits * np * hn]
            # -->(view)      [s, b, num_splits, np, hn]
            # -->(transpose) [s, b, np, num_splits, hn]
            # -->(view)      [s, b, np * num_splits * hn]
            intermediate_shape = input_shape[:-1] + (
                num_splits,
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            # [s, b, np * hn * num_splits]
            # -->(view)      [s, b, np, hn, num_splits]
            # -->(transpose) [s, b, np, num_splits, hn]
            # -->(view)      [s, b, np * num_splits * hn]
            intermediate_shape = input_shape[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
                num_splits,
            )

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor, int]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape [sq, b, h].
            attention_mask: additive mask added to the [b, np, sq, sk] scores,
                or `None` to skip masking.
            layer_past: optional `(past_key, past_value, past_length)` cache
                entry; keys and values are concatenated along dimension 0.
            use_cache: whether to return an updated cache entry.
            output_attentions: whether to also return the attention
                probabilities, shaped [b, np, sq, sk].

        Returns:
            `(output, present)` or `(output, present, attention_probs)`, where
            `output` is [sq, b, h] and `present` is `None` unless `use_cache`.
        """
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer, key_layer, value_layer) = torch.split(
            mixed_x_layer, self.hidden_size_per_attention_head, dim=-1
        )

        # New tokens start right after the cached ones.
        context_position = 0 if layer_past is None else layer_past[2]
        query_layer = self.rotary_position_encoding(query_layer, context_position)
        key_layer = self.rotary_position_encoding(key_layer, context_position)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value, sq_length = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat(
                (past_value.type_as(value_layer), value_layer), dim=0
            )
            # Advance by the number of new tokens rather than a fixed 1, so
            # that feeding several tokens alongside a cache stays consistent.
            sq_length += query_layer.size(0)
        else:
            sq_length = key_layer.size()[0]

        present = (key_layer, value_layer, sq_length) if use_cache else None

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(
            output_size[2], output_size[0] * output_size[1], -1
        )
        key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

        # Preallocated result tensor: [b * np, sq, sk]. Its contents are
        # ignored because beta=0.0 below.
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=query_layer.device,
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # Undo the extra per-layer division folded into `norm_factor`.
        if self.coeff is not None:
            attention_scores = attention_scores * self.coeff
        if attention_mask is not None:
            attention_scores += attention_mask
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)

        # NOTE: `self.attention_dropout` is intentionally not applied here.

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # context layer shape: [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(
            value_layer.size(0), output_size[0] * output_size[1], -1
        )

        # change view [b * np, sq, sk]; the 4-D tensor is kept for the return value
        flat_attention_probs = attention_probs.view(
            output_size[0] * output_size[1], output_size[2], -1
        )

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(flat_attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)
        output = (output, present)
        if output_attentions:
            # Previously appended to an undefined name `outputs`, which raised
            # NameError whenever attentions were requested.
            output += (attention_probs,)

        return output
382
+
383
+
384
class YalmMLP(nn.Module):
    """Feed-forward block of a YaLM layer.

    The input is projected from `hidden_size` to `intermediate_size`, passed
    through GELU, optionally multiplied element-wise by a second linear
    projection of the input (when `activation_type` is `"geglu"`), and
    projected back to `hidden_size`.
    """

    def __init__(self, config: YalmConfig):
        super().__init__()

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        self.dense_ffn_hidden = nn.Linear(hidden_size, intermediate_size)

        self.activation_type = config.activation_type
        self.is_gated = config.activation_type in ("geglu",)

        # GELU is used regardless of `activation_type`; the type only decides gating.
        self.activation_func = torch.nn.functional.gelu

        # The gate projection exists only for gated activations.
        if self.is_gated:
            self.dense_ffn_gate = nn.Linear(hidden_size, intermediate_size)

        self.dense_ffn_output = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states):
        """Apply the feed-forward block; the output shape matches the input shape."""
        activated = self.activation_func(self.dense_ffn_hidden(hidden_states))
        if self.is_gated:
            activated = activated * self.dense_ffn_gate(hidden_states)
        return self.dense_ffn_output(activated)
430
+
431
+
432
class YalmTransformerLayer(nn.Module):
    """One YaLM transformer block: self-attention followed by an MLP.

    The output has the same shape as the input. Every layer except the first
    (`layer_idx == 0`) normalizes its input before attention.
    """

    def __init__(self, config: YalmConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx

        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

        # The first layer has no input layer norm.
        if self.layer_idx > 0:
            self.input_layernorm = nn.LayerNorm(
                config.hidden_size, eps=config.layernorm_epsilon
            )

        self.attention = YalmSelfAttention(config, layer_idx)

        # Stored from the config; `forward` does not apply dropout.
        self.hidden_dropout = config.hidden_dropout

        # Normalization applied between the attention residual and the MLP.
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon
        )

        self.mlp = YalmMLP(config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, int]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        """Run the block.

        Returns a tuple `(hidden_states, present, attn_weights)` where
        `present` is included only when `use_cache` is truthy and
        `attn_weights` only when the attention module returns them.
        """
        # Input normalization (skipped for the first layer).
        if self.layer_idx > 0:
            normed_input = self.input_layernorm(hidden_states)
        else:
            normed_input = hidden_states

        attention_result = self.attention(
            normed_input,
            attention_mask,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # attention_result: attention_output, present, (attn_weights)
        attention_output, extras = attention_result[0], attention_result[1:]

        # The first residual is either the normalized or the raw input.
        skip = (
            normed_input
            if self.apply_residual_connection_post_layernorm
            else hidden_states
        )
        post_attention = attention_output + skip

        # MLP on the normalized attention result; the un-normalized sum is the residual.
        mlp_output = self.mlp(self.post_attention_layernorm(post_attention))
        block_output = mlp_output + post_attention

        # Drop the cache slot when caching is off: hidden_states, (attn_weights)
        if not use_cache:
            extras = extras[1:]
        return (block_output,) + extras
524
+
525
+
526
class YalmTransformer(nn.Module):
    """Stack of `config.num_layers` YaLM transformer layers."""

    def __init__(self, config: YalmConfig):
        super().__init__()

        # Number of layers:
        self.num_layers = config.num_layers

        self.layers = torch.nn.ModuleList(
            [YalmTransformerLayer(config, layer_idx=i) for i in range(self.num_layers)]
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gradient_checkpointing: bool = False,
    ):
        """Run all layers in order.

        Args:
            hidden_states: input of shape [b, s, h].
            attention_mask: mask forwarded unchanged to every layer.
            past_key_values: one cache entry per layer, or `None` for no cache.
            use_cache: whether to collect each layer's cache entry.
            output_attentions: whether to collect each layer's attention weights.
            output_hidden_states: whether to collect the hidden states before
                each layer and after the last one. These are collected in the
                internal [s, b, h] layout.
            gradient_checkpointing: recompute activations during backward;
                only takes effect in training mode.

        Returns:
            `(output, presents, all_hidden_states, all_attentions)` where
            `output` is [b, s, h] and each of the other items is a tuple, or
            `None` when the corresponding flag is not set.
        """
        # data format change to avoid explicit transposes: [b s h] --> [s b h]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        # With no cache supplied, give every layer an empty past; zipping the
        # layers against `None` would raise TypeError.
        if past_key_values is None:
            past_key_values = (None,) * len(self.layers)

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for layer, layer_past in zip(self.layers, past_key_values):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Positional order of the layer's forward is
                        # (hidden_states, attention_mask, layer_past,
                        # use_cache, output_attentions); no past is passed
                        # when checkpointing.
                        return module(*inputs, None, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            hidden_states = outputs[0]
            if use_cache:
                presents = presents + (outputs[1],)
            if output_attentions:
                # Attention weights follow the cache entry when one is returned.
                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # reverting data format change [s b h] --> [b s h]
        output = hidden_states.transpose(0, 1).contiguous()

        return output, presents, all_hidden_states, all_attentions
594
+
595
+
596
class YalmProjector(nn.Module):
    """Maps embeddings of width `embedding_size` to width `hidden_size`.

    The input is layer-normalized unless
    `apply_residual_connection_post_layernorm` is set. When the two widths
    differ, the result is multiplied by a fixed rectangular identity matrix
    held in a non-persistent buffer. The `dtype` and `device` arguments are
    accepted but not used.
    """

    def __init__(self, config: YalmConfig, dtype, device):
        super().__init__()

        self.embedding_size = config.embedding_size
        self.hidden_size = config.hidden_size
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

        if not self.apply_residual_connection_post_layernorm:
            self.input_layernorm = nn.LayerNorm(
                config.embedding_size, eps=config.layernorm_epsilon
            )

        if config.embedding_size != config.hidden_size:
            # Rectangular identity; kept out of the state dict.
            identity = torch.eye(config.embedding_size, config.hidden_size)
            self.register_buffer("projector", identity, persistent=False)

    def forward(self, data):
        """Normalize (if configured) and project `data` along its last dimension."""
        hidden_states = (
            data
            if self.apply_residual_connection_post_layernorm
            else self.input_layernorm(data)
        )

        if self.embedding_size == self.hidden_size:
            return hidden_states
        return hidden_states @ self.projector
631
+
632
+
633
class YalmOutputLayer(nn.Module):
    """Maps final hidden states from `hidden_size` back to `embedding_size`.

    Applies, in order: layer norm, a linear projection, GELU, and a second
    layer norm over the projected width.
    """

    def __init__(self, config: YalmConfig) -> None:
        super().__init__()
        eps = config.layernorm_epsilon

        self.input_layer_norm = nn.LayerNorm(config.hidden_size, eps=eps)
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)
        self.activation = torch.nn.functional.gelu
        self.output_layer_norm = nn.LayerNorm(config.embedding_size, eps=eps)

    def forward(self, input_data):
        """Return `input_data` projected to `embedding_size` along its last dimension."""
        projected = self.dense(self.input_layer_norm(input_data))
        return self.output_layer_norm(self.activation(projected))
658
+
659
+
660
+ YALM_START_DOCSTRING = r"""
661
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
662
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
663
+ etc.)
664
+
665
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
666
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
667
+ and behavior.
668
+
669
+ Parameters:
670
+ config ([`YalmConfig`]):
671
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
672
+ load the weights associated with the model, only the configuration. Check out the
673
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
674
+ """
675
+
676
+
677
@add_start_docstrings(
    "The bare Yalm Model outputting raw hidden-states without any specific head on top.",
    YALM_START_DOCSTRING,
)
class YalmPreTrainedModel(PreTrainedModel):
    # Common base for the YaLM model classes: binds the config class, the
    # attribute name of the wrapped base model and the weight-init policy.
    config_class = YalmConfig
    base_model_prefix = "yalm"
    supports_gradient_checkpointing = True
    _no_split_modules = ["YalmTransformerLayer"]

    def _init_weights(self, module):
        """Initialise `nn.Linear` / `nn.Embedding` weights from N(0, initializer_range).

        Linear biases are zeroed and, for embeddings, the `padding_idx` row is
        zeroed. Any other module type is left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on `module` when it is a `YalmModel`."""
        if not isinstance(module, YalmModel):
            return
        module.gradient_checkpointing = value
701
+
702
+
703
# Argument documentation shared by the `forward` methods of the YaLM model classes.
# NOTE: `YalmForCausalLM.forward` calls `.format(...)` on this string, so it must
# not contain literal curly braces.
YALM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`YalmModel._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
760
+
761
+
762
@add_start_docstrings(
    "The bare YaLM Model outputting raw hidden-states without any specific head on top.",
    YALM_START_DOCSTRING,
)
class YalmModel(YalmPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`YalmDecoderLayer`]

    Args:
        config: YalmConfig
    """

    # NOTE(review): the docstring above names `YalmDecoderLayer`, whereas
    # `_no_split_modules` on the base class lists "YalmTransformerLayer" -
    # confirm which class name is the real one.

    def __init__(self, config: YalmConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.padded_vocab_size = config.padded_vocab_size

        # Pipeline: token embedding -> projector -> transformer stack -> output layer.
        self.embed_tokens = nn.Embedding(
            config.padded_vocab_size, config.embedding_size, self.padding_idx
        )
        embedding_weight = self.embed_tokens.weight
        self.projector = YalmProjector(
            config, embedding_weight.dtype, embedding_weight.device
        )
        self.transformer = YalmTransformer(config)
        self.output_layer = YalmOutputLayer(config)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Combine the causal mask with the expanded padding mask.

        Returns the causal mask alone when `attention_mask` is None, the
        expanded padding mask alone when only one query position is given,
        their sum when both exist, and None when neither applies.
        """
        target_len = input_shape[-1]

        # A causal mask is only built when there is more than one query position.
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_mask = None
        if target_len > 1:
            causal_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is None:
            return causal_mask

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        padding_mask = _expand_mask(
            attention_mask, inputs_embeds.dtype, tgt_len=target_len
        ).to(inputs_embeds.device)

        if causal_mask is None:
            return padding_mask
        return padding_mask + causal_mask

    @add_start_docstrings_to_model_forward(YALM_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Per-call flags fall back to the model config when left as None.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if use_cache is None:
            use_cache = self.config.use_cache

        # Exactly one of input_ids / inputs_embeds must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # The cached length is read from dim 2 of the first cached tensor.
        if past_key_values is None:
            past_key_values_length = 0
            past_key_values = (None,) * self.config.num_layers
        else:
            past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length + past_key_values_length

        # Without a user mask every position (cached and new) is attended to.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        projected = self.projector(inputs_embeds)

        transformer_out, presents, all_hidden_states, all_attentions = self.transformer(
            projected,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        last_hidden_states = self.output_layer(transformer_out)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (last_hidden_states,)

        if not return_dict:
            # Tuple output omits every entry that is None.
            candidates = (
                last_hidden_states,
                presents,
                all_hidden_states,
                all_attentions,
            )
            return tuple(item for item in candidates if item is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=last_hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
923
+
924
+
925
@add_start_docstrings(
    """
    YaLM Model with a `language modeling` head on top (linear layer with weights tied to the input
    embeddings).
    """,
    YALM_START_DOCSTRING,
)
class YalmForCausalLM(YalmPreTrainedModel):
    _tied_weights_keys = [r"yalm.embed_tokens.weight", r"lm_head.weight"]

    def __init__(self, config: YalmConfig):
        super().__init__(config)

        self.yalm = YalmModel(config)
        # The head maps embedding-sized states to vocabulary logits; the bias is
        # kept as a separate parameter and added in `forward`.
        self.lm_head = nn.Linear(
            config.embedding_size, config.padded_vocab_size, bias=False
        )
        self.out_bias = torch.nn.Parameter(
            torch.zeros(
                config.padded_vocab_size,
            )
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(
        YALM_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.num_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.padded_vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100`
            are ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.padded_vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, YalmForCausalLM, YalmConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("TODO")
        >>> config = YalmConfig.from_pretrained("TODO")
        >>> config.is_decoder = True
        >>> model = YalmForCausalLM.from_pretrained("TODO", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.yalm(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        lm_logits = self.lm_head(hidden_states) + self.out_bias

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shift_logits = lm_logits[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithPast(
            loss=lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Args:
            input_ids: Token ids generated so far.
            past_key_values: Cache from previous steps; when it is populated only
                the last token of `input_ids` is passed on.
            attention_mask: Optional mask; defaults to all ones over the full
                `input_ids` shape.
            **kwargs: Extra generation kwargs. Only `use_cache` is read.

        Returns:
            Dict with `input_ids`, `attention_mask`, `past_key_values` and
            `use_cache`.
        """
        # The shape is captured before truncation so a default mask covers the
        # whole sequence, cached positions included.
        input_shape = input_ids.shape

        # cut decoder_input_ids if past is used
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            # Pass the caller's choice through; None lets `forward` fall back
            # to `config.use_cache`. Previously this was dropped silently.
            "use_cache": kwargs.get("use_cache"),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the batch dimension of the cache to follow `beam_idx`.

        The first two tensors of each layer entry are re-indexed along dim 0;
        any further entries are carried over unchanged.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx)
                    for past_state in layer_past[:2]
                )
                + layer_past[2:],
            )
        return reordered_past
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e63c3d3f3883978bc756b8d8d75183923e17fc90fa76c61bcafa0ddb5dcc2b4
3
+ size 2815034
tokenization_yalm.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 T5 Authors and HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization class for model YaLM."""
16
+
17
+
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
19
+
20
+ import numpy as np
21
+ import sentencepiece as spm
22
+ import six
23
+ from transformers.convert_slow_tokenizer import import_protobuf
24
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.tokenization_utils_base import TextInput
28
+
29
+ from transformers.utils import logging
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
# Name of the SentencePiece model file that the tokenizer loads.
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

# NOTE(review): every entry below is a T5 checkpoint, which looks like a
# carry-over from the T5 tokenizer this module was adapted from - confirm
# whether these maps should be emptied or pointed at YaLM files.
PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "t5-small": "https://huggingface.co/t5-small/resolve/main/spiece.model",
        "t5-base": "https://huggingface.co/t5-base/resolve/main/spiece.model",
        "t5-large": "https://huggingface.co/t5-large/resolve/main/spiece.model",
        "t5-3b": "https://huggingface.co/t5-3b/resolve/main/spiece.model",
        "t5-11b": "https://huggingface.co/t5-11b/resolve/main/spiece.model",
    }
}


# TODO(PVP) - this should be removed in Transformers v5
# Maximum input size per checkpoint name; keys are again T5 checkpoint names.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "t5-small": 512,
    "t5-base": 512,
    "t5-large": 512,
    "t5-3b": 512,
    "t5-11b": 512,
}
54
+
55
+
56
def convert_to_unicode(text):
    """Return `text` as `str`, decoding `bytes` as UTF-8 and dropping undecodable bytes.

    `str` input is returned unchanged; any other type raises `TypeError`.
    """
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    if isinstance(text, str):
        return text
    raise TypeError("not expecting type '%s'" % type(text))
59
+
60
+
61
class YalmTokenizer(PreTrainedTokenizer):
    """
    Construct a YaLM tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Newlines are encoded as the literal marker `[NL]` before SentencePiece runs and restored by `decode`.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token. `tokenize` always prepends it.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token. It is stripped from the text returned by `decode`.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The mask token, forwarded to the base class.
        pad_token (`str`, *optional*):
            The token used for padding, for example when batching sequences of different lengths. No padding token
            is set by default.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Forwarded to the base class. `tokenize` prepends `bos_token` regardless of this flag.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Forwarded to the base class; this class does not append `eos_token` itself.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    SPIECE_UNDERLINE = r"▁"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        mask_token="[MASK]",
        pad_token=None,
        additional_special_tokens=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        **kwargs,
    ) -> None:
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        def _wrap(token):
            # Plain strings become AddedToken with stripping disabled on both sides.
            if isinstance(token, str):
                return AddedToken(token, lstrip=False, rstrip=False)
            return token

        bos_token = _wrap(bos_token)
        eos_token = _wrap(eos_token)
        unk_token = _wrap(unk_token)
        pad_token = _wrap(pad_token)

        # The SentencePiece model and lookup tables are built BEFORE calling the
        # base initialiser: newer transformers releases query the vocabulary
        # from inside `PreTrainedTokenizer.__init__`, which would otherwise hit
        # attributes that do not exist yet.
        self.vocab_file = vocab_file
        self.sp_model = self.get_spm_processor()
        self._vocab_words = self._get_vocab_words()
        self.encoder = {token: idx for idx, token in enumerate(self._vocab_words)}
        self.decoder = dict(enumerate(self._vocab_words))

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            sp_model_kwargs=self.sp_model_kwargs,
            legacy=False,
            **kwargs,
        )

    def get_spm_processor(self):
        """Load `self.vocab_file` with `add_dummy_prefix` disabled in the normalizer spec."""
        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)

        with open(self.vocab_file, "rb") as f:
            sp_model = f.read()
        # Patch the serialized model so SentencePiece does not insert its own
        # leading "▁"; `tokenize` adds that prefix explicitly.
        model_pb2 = import_protobuf()
        model = model_pb2.ModelProto.FromString(sp_model)
        normalizer_spec = model_pb2.NormalizerSpec()
        normalizer_spec.add_dummy_prefix = False
        model.normalizer_spec.MergeFrom(normalizer_spec)
        sp_model = model.SerializeToString()
        tokenizer.LoadFromSerializedProto(sp_model)
        return tokenizer

    @property
    def vocab_size(self):
        """Number of pieces in the SentencePiece model."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return the token -> id mapping, including tokens added on top of the SentencePiece vocab."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        # The SentencePiece processor is dropped from the pickled state and
        # rebuilt from `vocab_file` in `__setstate__`.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        # Rebuild through the same path as `__init__` so the
        # `add_dummy_prefix = False` override also applies after unpickling;
        # a plain `Load` would tokenize differently.
        self.sp_model = self.get_spm_processor()

    def tokenize(
        self, text: "TextInput", add_special_tokens=False, **kwargs
    ) -> List[str]:
        """
        Converts a string to a list of tokens.

        Newlines are replaced by `[NL]`, the text is prefixed with "▁" and encoded with SentencePiece, and
        `bos_token` is prepended. `add_special_tokens` and `**kwargs` are accepted but not used.
        """
        text = convert_to_unicode(text)
        text = text.replace("\n", "[NL]")
        return [self.bos_token] + self.sp_model.encode(
            YalmTokenizer.SPIECE_UNDERLINE + text, out_type=str
        )

    def decode(
        self,
        token_ids,
        **kwargs,
    ) -> str:
        """
        Converts a sequence of SentencePiece ids back to text.

        Args:
            token_ids: A single id or a one-dimensional sequence of ids. Objects exposing `tolist()` (for example
                torch tensors or numpy arrays) are converted to plain Python values first.
            **kwargs: Accepted for signature compatibility with the base class; not used.

        Returns:
            The decoded string with "▁" turned into spaces, `eos_token` removed, leading whitespace stripped and
            `[NL]` turned back into newlines.

        Raises:
            KeyError: If an id is not part of the SentencePiece vocabulary.
        """
        # Tensor elements are not usable as dict keys, so normalise to Python ints.
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        tokens = [self.decoder[idx] for idx in token_ids]
        text = (
            "".join(tokens)
            .replace("\u2581", " ")
            .replace(self.eos_token, "")
            .lstrip()
            .replace("[NL]", "\n")
        )
        return text

    def _convert_token_to_id(self, token):
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index: int) -> str:
        return self.decoder[index]

    def _get_vocab_words(self):
        """Return every SentencePiece piece, ordered by id."""
        indices = list(range(self.sp_model.GetPieceSize()))
        return self.sp_model.id_to_piece(indices)