luxq committed on
Commit
93e390f
1 Parent(s): cd5caea

upload config, tokenizer and modeling file

Files changed (6)
  1. README.md +3 -3
  2. config.json +4 -0
  3. configuration_geblm.py +60 -0
  4. modeling_geb.py +1185 -0
  5. tokenization_geb.py +280 -0
  6. tokenizer_config.json +12 -0
README.md CHANGED
@@ -1,3 +1,3 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
config.json ADDED
@@ -0,0 +1,4 @@
1
+ { "auto_map": {
2
+ "AutoConfig": "configuration_geblm.GEBConfig",
3
+ "AutoModel": "modeling_geb.GEBForCausalLM"
4
+ }, "num_layers": 24, "padded_vocab_size": 64896, "hidden_size": 2048, "ffn_hidden_size": 5632, "kv_channels": 128, "num_attention_heads": 16, "torch_dtype": "bfloat16", "seq_length": 4096, "hidden_dropout": 0.0, "attention_dropout": 0.0, "layernorm_epsilon": 1e-05, "max_position_embeddings": 4096, "bias_dropout_fusion": true, "use_cache": true, "apply_residual_connection_post_layernorm": false, "post_layer_norm": true, "add_bias_linear": false, "use_flash_attn": false, "num_key_value_heads": 4, "apply_query_key_layer_scaling": false, "attention_softmax_in_fp32": false, "fp32_residual_connection": false, "pre_seq_len": null, "prefix_projection": false, "tie_word_embeddings": false, "return_dict": true, "output_hidden_states": false, "output_attentions": false, "torchscript": false, "use_bfloat16": true, "tf_legacy_loss": false, "pruned_heads": {}, "is_encoder_decoder": false, "is_decoder": false, "cross_attention_hidden_size": null, "add_cross_attention": false, "tie_encoder_decoder": false, "max_length": 512, "min_length": 0, "do_sample": true, "early_stopping": false, "num_beams": 1, "num_beam_groups": 1, "diversity_penalty": 0.0, "temperature": 0.3, "top_k": 5, "top_p": 0.5, "typical_p": 1.0, "repetition_penalty": 1.15, "length_penalty": 1.0, "no_repeat_ngram_size": 0, "encoder_no_repeat_ngram_size": 0, "bad_words_ids": null, "num_return_sequences": 1, "chunk_size_feed_forward": 0, "output_scores": false, "return_dict_in_generate": false, "forced_bos_token_id": null, "forced_eos_token_id": null, "remove_invalid_values": false, "exponential_decay_length_penalty": null, "suppress_tokens": null, "begin_suppress_tokens": null, "architectures": ["GEBForCausalLM"], "finetuning_task": null, "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, "label2id": {"LABEL_0": 0, "LABEL_1": 1}, "tokenizer_class": null, "prefix": null, "bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 2, "_name_or_path": "", "transformers_version": "4.35.2", "model_type": "geblm"}
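The `auto_map` block ties the Hub's Auto classes to the custom code added in this commit, so loading requires `trust_remote_code=True`. Below is a minimal loading sketch; the repo path is a placeholder, the tokenizer line assumes tokenizer_config.json (added in this commit but not shown here) registers `GEBTokenizer` the same way, and note that modeling_geb.py imports `deepspeed.accelerator`, so DeepSpeed must be installed.

```python
from transformers import AutoConfig, AutoModel, AutoTokenizer

repo = "path/to/this/repo"  # placeholder: local clone or Hub repo id

# auto_map routes AutoConfig -> configuration_geblm.GEBConfig and
# AutoModel -> modeling_geb.GEBForCausalLM
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
```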
configuration_geblm.py ADDED
@@ -0,0 +1,60 @@
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class GEBConfig(PretrainedConfig):
5
+ model_type = "geblm"
6
+ def __init__(
7
+ self,
8
+ num_layers=24,
9
+ padded_vocab_size=64896,
10
+ hidden_size=2048,
11
+ ffn_hidden_size=5632,
12
+ kv_channels=128,
13
+ num_attention_heads=16,
14
+ torch_dtype='bfloat16',
15
+ seq_length=4096,
16
+ hidden_dropout=0.0,
17
+ attention_dropout=0.0,
18
+ layernorm_epsilon=1e-5,
19
+ max_position_embeddings=4096,
20
+ bias_dropout_fusion=True,
21
+ use_cache=True,
22
+ apply_residual_connection_post_layernorm=False,
23
+ post_layer_norm=True,
24
+ add_bias_linear=False,
25
+ use_flash_attn=True,
26
+ num_key_value_heads=4,
27
+ apply_query_key_layer_scaling=False,
28
+ attention_softmax_in_fp32=False,
29
+ fp32_residual_connection=False,
30
+ pre_seq_len=None,
31
+ prefix_projection=False,
32
+ tie_word_embeddings=False,
33
+ **kwargs
34
+ ):
35
+ self.num_layers=num_layers
36
+ self.padded_vocab_size=padded_vocab_size
37
+ self.hidden_size=hidden_size
38
+ self.ffn_hidden_size=ffn_hidden_size
39
+ self.kv_channels=kv_channels
40
+ self.num_attention_heads=num_attention_heads
41
+ self.torch_dtype=torch_dtype
42
+ self.seq_length=seq_length
43
+ self.hidden_dropout=hidden_dropout
44
+ self.attention_dropout=attention_dropout
45
+ self.layernorm_epsilon=layernorm_epsilon
46
+ self.max_position_embeddings=max_position_embeddings
47
+ self.bias_dropout_fusion=bias_dropout_fusion
48
+ self.use_cache=use_cache
49
+ self.apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm
50
+ self.post_layer_norm=post_layer_norm
51
+ self.add_bias_linear=add_bias_linear
52
+ self.use_flash_attn=use_flash_attn
53
+ self.num_key_value_heads=num_key_value_heads
54
+ self.apply_query_key_layer_scaling=apply_query_key_layer_scaling
55
+ self.attention_softmax_in_fp32=attention_softmax_in_fp32
56
+ self.fp32_residual_connection=fp32_residual_connection
57
+ self.pre_seq_len=pre_seq_len
58
+ self.prefix_projection=prefix_projection
59
+ self.tie_word_embeddings=tie_word_embeddings
60
+ super().__init__(**kwargs)
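A quick sanity-check sketch for the config class, assuming configuration_geblm.py is importable from the working directory. The class defaults do not all match the uploaded config.json (for example, `use_flash_attn` defaults to True here but is false in config.json), so override explicitly where it matters.

```python
from configuration_geblm import GEBConfig

# override the one default that differs from the uploaded config.json
config = GEBConfig(use_flash_attn=False)

print(config.num_layers, config.hidden_size, config.num_key_value_heads)
# expected: 24 2048 4
```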
modeling_geb.py ADDED
@@ -0,0 +1,1185 @@
1
+ """PyTorch GEB model."""
2
+
3
+ import math
4
+ import copy
5
+ import os
6
+ import warnings
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Tuple, Dict, Any, List
9
+ import importlib.util
10
+ from torch.nn.utils import skip_init
11
+ import torch.nn.functional as F
12
+ import torch
13
+ import torch.utils.checkpoint
14
+ from torch import einsum, nn
15
+ from torch.cuda.amp import autocast
16
+ from torch.nn import BCEWithLogitsLoss, LayerNorm, CrossEntropyLoss, MSELoss
17
+ from copy import deepcopy
18
+ from deepspeed.accelerator import get_accelerator
19
+ try:
20
+ from einops import rearrange
21
+ except ImportError:
22
+ rearrange = None
23
+ from transformers.modeling_outputs import (
24
+ BaseModelOutputWithPast,
25
+ CausalLMOutputWithPast,
26
+ QuestionAnsweringModelOutput,
27
+ SequenceClassifierOutputWithPast,
28
+ TokenClassifierOutput,
29
+ )
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import (
32
+ ModelOutput,
33
+ add_code_sample_docstrings,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from transformers.generation.logits_process import LogitsProcessor
40
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
41
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
42
+ from .configuration_geblm import GEBConfig
43
+ try:
44
+ # FlashAttention-2
45
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
46
+ except ImportError:
47
+ flash_attn_varlen_func = None
48
+ FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder")
49
+ flash_attn_builder = None
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CHECKPOINT_FOR_DOC = "geb"
55
+ _CONFIG_FOR_DOC = "GEBConfig"
56
+
57
+ def _config_to_kwargs(args):
58
+ common_kwargs = {
59
+ "dtype": args.torch_dtype,
60
+ }
61
+ return common_kwargs
62
+
63
+ def default_init(cls, *args, **kwargs):
64
+ return cls(*args, **kwargs)
65
+
66
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
67
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
68
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
69
+ scores.zero_()
70
+ scores[..., 5] = 5e4
71
+ return scores
72
+
73
+ def split_tensor_along_last_dim(
74
+ tensor: torch.Tensor,
75
+ num_partitions: int,
76
+ contiguous_split_chunks: bool = False,
77
+ ) -> List[torch.Tensor]:
78
+ """ Split a tensor along its last dimension.
79
+
80
+ Arguments:
81
+ tensor: input tensor.
82
+ num_partitions: number of partitions to split the tensor
83
+ contiguous_split_chunks: If True, make each chunk contiguous
84
+ in memory.
85
+
86
+ Returns:
87
+ A list of Tensors
88
+ """
89
+ # Get the size and dimension.
90
+ last_dim = tensor.dim() - 1
91
+ last_dim_size = tensor.size()[last_dim] // num_partitions
92
+ # Split.
93
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
94
+ # Note: torch.split does not create contiguous tensors by default.
95
+ if contiguous_split_chunks:
96
+ return tuple(chunk.contiguous() for chunk in tensor_list)
97
+
98
+ return tensor_list
99
+
100
+ class PrefixEncoder(torch.nn.Module):
101
+ """
102
+ The torch.nn model to encode the prefix
103
+ Input shape: (batch-size, prefix-length)
104
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
105
+ """
106
+
107
+ def __init__(self, config: GEBConfig):
108
+ super().__init__()
109
+ self.prefix_projection = config.prefix_projection
110
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
111
+ if self.prefix_projection:
112
+ # Use a two-layer MLP to encode the prefix
113
+ kv_size = config.num_layers * config.kv_channels * self.num_key_value_groups * 2
114
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
115
+ self.trans = torch.nn.Sequential(
116
+ torch.nn.Linear(kv_size, config.hidden_size),
117
+ torch.nn.Tanh(),
118
+ torch.nn.Linear(config.hidden_size, kv_size)
119
+ )
120
+ else:
121
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
122
+ config.num_layers * config.kv_channels * self.num_key_value_groups * 2)
123
+
124
+ def forward(self, prefix: torch.Tensor):
125
+ if self.prefix_projection:
126
+ prefix_tokens = self.embedding(prefix)
127
+ past_key_values = self.trans(prefix_tokens)
128
+ else:
129
+ past_key_values = self.embedding(prefix)
130
+ return past_key_values
131
+
132
+ # class RotaryEmbedding(nn.Module):
133
+ # def __init__(self, dim, original_impl=False, device=None, dtype=None):
134
+ # super().__init__()
135
+ # inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
136
+ # self.register_buffer("inv_freq", inv_freq)
137
+ # self.dim = dim
138
+ # self.original_impl = original_impl
139
+
140
+ # def forward_impl(
141
+ # self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
142
+ # ):
143
+ # """Enhanced Transformer with Rotary Position Embedding.
144
+
145
+ # Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
146
+ # transformers/rope/__init__.py. MIT License:
147
+ # https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
148
+ # """
149
+ # # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
150
+ # theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
151
+
152
+ # # Create position indexes `[0, 1, ..., seq_len - 1]`
153
+ # seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
154
+
155
+ # # Calculate the product of position index and $\theta_i$
156
+ # idx_theta = torch.outer(seq_idx, theta).float()
157
+
158
+ # cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
159
+
160
+ # # this is to mimic the behaviour of complex32, else we will get different results
161
+ # if dtype in (torch.float16, torch.bfloat16, torch.int8):
162
+ # cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
163
+ # return cache
164
+
165
+ # def forward(self, max_seq_len, offset=0):
166
+ # return self.forward_impl(
167
+ # max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
168
+ # )
169
+
170
+
171
+ # @torch.jit.script
172
+ # def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
173
+ # # x: [sq, b, np, hn]
174
+ # sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
175
+ # rot_dim = rope_cache.shape[-2] * 2
176
+ # x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
177
+ # # truncate to support variable sizes
178
+ # rope_cache = rope_cache[:sq]
179
+ # xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
180
+ # rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
181
+ # x_out2 = torch.stack(
182
+ # [
183
+ # xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
184
+ # xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
185
+ # ],
186
+ # -1,
187
+ # )
188
+ # x_out2 = x_out2.flatten(3)
189
+ # return torch.cat((x_out2, x_pass), dim=-1)
190
+
191
+
192
+ class RotaryEmbedding(nn.Module):
193
+ def __init__(self, dim):
194
+ super().__init__()
195
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
196
+ self.register_buffer('inv_freq', inv_freq)
197
+ if importlib.util.find_spec('einops') is None:
198
+ raise RuntimeError("einops is required for Rotary Embedding")
199
+
200
+ def forward(self, max_seq_len, offset=0):
201
+ seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
202
+ # Calculate the product of seq and inv_freq
203
+ freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
204
+ # first part even vector components, second part odd vector components,
205
+ # 2 * dim in dimension size
206
+ emb = torch.cat((freqs, freqs), dim=-1)
207
+ # emb [seq_length, .., dim]
208
+ from einops import rearrange
209
+ # print('rearrange:', rearrange(emb, 'n d -> n 1 1 d').size())
210
+ return rearrange(emb, 'n d -> n 1 1 d')
211
+
212
+
213
+ def _rotate_half(x):
214
+ """
215
+ change sign so the last dimension becomes [-odd, +even]
216
+ """
217
+ from einops import rearrange
218
+ x = rearrange(x, '... (j d) -> ... j d', j=2)
219
+ x1, x2 = x.unbind(dim=-2)
220
+ return torch.cat((-x2, x1), dim=-1)
221
+
222
+
223
+ def apply_rotary_pos_emb(t, freqs):
224
+ """
225
+ input tensor t is of shape [seq_length, ..., dim]
226
+ rotary positional embedding tensor freqs is of shape [seq_length, ..., dim]
227
+ check https://kexue.fm/archives/8265 for detailed formulas
228
+ """
229
+ # print('t:', t.size())
230
+ # print('freqs:', freqs.size())
231
+ rot_dim = freqs.shape[-1]
232
+ # print('rot_dim:', rot_dim)
233
+ # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
234
+ t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
235
+
236
+ # first part is cosine component
237
+ # second part is sine component, need to change signs with _rotate_half method
238
+ # print(t.shape, t_pass.shape, freqs.shape)
239
+ t = (t * freqs.cos().to(t.dtype)) + (_rotate_half(t) * freqs.sin().to(t.dtype))
240
+
241
+ return torch.cat((t, t_pass), dim=-1)
242
+
243
+
244
+ class RMSNorm(torch.nn.Module):
245
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
246
+ super().__init__()
247
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
248
+ self.eps = eps
249
+
250
+ def forward(self, hidden_states: torch.Tensor):
251
+ input_dtype = hidden_states.dtype
252
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
253
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
254
+
255
+ return (self.weight * hidden_states).to(input_dtype)
256
+
257
+
258
+ class MLP(torch.nn.Module):
259
+ """MLP.
260
+
261
+ MLP will take the input with h hidden state, project it to 4*h
262
+ hidden dimension, perform nonlinear transformation, and project the
263
+ state back into h hidden dimension.
264
+ """
265
+
266
+ def __init__(self, config: GEBConfig, device=None):
267
+ super(MLP, self).__init__()
268
+
269
+ self.add_bias = config.add_bias_linear #false
270
+
271
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
272
+ self.dense_h_to_4h = nn.Linear(
273
+ config.hidden_size,
274
+ config.ffn_hidden_size * 2, # config.ffn_hidden_size * 2
275
+ bias=self.add_bias,
276
+ device=device,
277
+ **_config_to_kwargs(config)
278
+ )
279
+
280
+ def swiglu(x):
281
+ x = torch.chunk(x, 2, dim=-1)
282
+ return F.silu(x[0]) * x[1]
283
+
284
+ self.activation_func = swiglu
285
+
286
+ # Project back to h.
287
+ self.dense_4h_to_h = nn.Linear(
288
+ config.ffn_hidden_size,
289
+ config.hidden_size,
290
+ bias=self.add_bias,
291
+ device=device,
292
+ **_config_to_kwargs(config)
293
+ )
294
+
295
+ def forward(self, hidden_states):
296
+ # [s, b, 4hp]
297
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
298
+ intermediate_parallel = self.activation_func(intermediate_parallel)
299
+ # [s, b, h]
300
+ output = self.dense_4h_to_h(intermediate_parallel)
301
+ return output
302
+
303
+
304
+ class CoreAttention(torch.nn.Module):
305
+
306
+ def __init__(self, config: GEBConfig, layer_number):
307
+ super(CoreAttention, self).__init__()
308
+ # self.fp16 = config.fp16
309
+ # self.bf16 = config.bf16
310
+
311
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
312
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
313
+ if self.apply_query_key_layer_scaling:
314
+ self.attention_softmax_in_fp32 = True
315
+ self.layer_number = max(1, layer_number)
316
+ self.num_layers = config.num_layers
317
+
318
+ projection_size = config.kv_channels * config.num_attention_heads
319
+
320
+ # Per attention head and per partition values.
321
+ self.hidden_size_per_partition = projection_size
322
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
323
+ self.num_attention_heads_per_partition = config.num_attention_heads
324
+
325
+ coeff = None
326
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
327
+ if self.apply_query_key_layer_scaling:
328
+ coeff = self.layer_number
329
+ self.norm_factor *= coeff
330
+ self.coeff = coeff
331
+ # Dropout. Note that for a single iteration, this layer will generate
332
+ # different outputs on different number of parallel partitions but
333
+ # on average it should not be partition dependent.
334
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
335
+
336
+ def forward(self, query_layer, key_layer,
337
+ value_layer, attention_mask):
338
+
339
+ # ===================================
340
+ # Raw attention scores. [b, np, s, s]
341
+ # ===================================
342
+
343
+ # [b, np, sq, sk]
344
+ output_size = (query_layer.size(1),
345
+ query_layer.size(2),
346
+ query_layer.size(0),
347
+ key_layer.size(0))
348
+
349
+ # [sq, b, np, hn] -> [sq, b * np, hn]
350
+ query_layer = query_layer.view(output_size[2],
351
+ output_size[0] * output_size[1], -1)
352
+ # [sk, b, np, hn] -> [sk, b * np, hn]
353
+ key_layer = key_layer.view(output_size[3],
354
+ output_size[0] * output_size[1], -1)
355
+
356
+ # preallocate input tensor [b * np, sq, sk] to store the matrix multiplication of query and key
357
+ matmul_input_buffer = torch.empty(
358
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
359
+ device=query_layer.device
360
+ )
361
+
362
+ # Raw attention scores. [b * np, sq, sk]
363
+ matmul_result = torch.baddbmm(
364
+ matmul_input_buffer,
365
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
366
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
367
+ beta=0.0, alpha=(1.0/self.norm_factor))
368
+
369
+ # change view to [b, np, sq, sk]
370
+ attention_scores = matmul_result.view(*output_size)
371
+
372
+ # ===========================
373
+ # Attention probs and dropout
374
+ # ===========================
375
+
376
+ # attention scores and attention mask [b, np, sq, sk]
377
+ if self.attention_softmax_in_fp32:
378
+ attention_scores = attention_scores.float()
379
+ if self.coeff is not None:
380
+ attention_scores = attention_scores * self.coeff
381
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
382
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
383
+ device=attention_scores.device, dtype=torch.bool)
384
+ attention_mask.tril_()
385
+ attention_mask = ~attention_mask
386
+ if attention_mask is not None:
387
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
388
+ attention_probs = F.softmax(attention_scores, dim=-1)
389
+ attention_probs = attention_probs.type_as(value_layer)
390
+
391
+ # This is actually dropping out entire tokens to attend to, which might
392
+ # seem a bit unusual, but is taken from the original Transformer paper.
393
+ attention_probs = self.attention_dropout(attention_probs)
394
+
395
+ # =========================
396
+ # Context layer. [sq, b, hp]
397
+ # =========================
398
+
399
+ # value_layer -> context layer.
400
+ # [sk, b, np, hn] --> [b, np, sq, hn]
401
+
402
+ # context layer shape: [b, np, sq, hn]
403
+ output_size = (value_layer.size(1),
404
+ value_layer.size(2),
405
+ query_layer.size(0),
406
+ value_layer.size(3))
407
+
408
+ # change view [sk, b * np, hn]
409
+ value_layer = value_layer.contiguous().view(value_layer.size(0),
410
+ output_size[0] * output_size[1], -1)
411
+
412
+ # change view [b * np, sq, sk]
413
+ attention_probs = attention_probs.view(output_size[0] * output_size[1],
414
+ output_size[2], -1)
415
+
416
+ # matmul: [b * np, sq, hn]
417
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
418
+
419
+ # change view [b, np, sq, hn]
420
+ context_layer = context_layer.view(*output_size)
421
+
422
+ # [b, np, sq, hn] --> [sq, b, np, hn]
423
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
424
+
425
+ # [sq, b, np, hn] --> [sq, b, hp]
426
+ new_context_layer_shape = context_layer.size()[:-2] + \
427
+ (self.hidden_size_per_partition,)
428
+ context_layer = context_layer.view(*new_context_layer_shape)
429
+
430
+ return context_layer
431
+
432
+ class FlashSelfAttention(torch.nn.Module):
433
+ """Implement the scaled dot product attention with softmax.
434
+ Arguments
435
+ ---------
436
+ softmax_scale: The temperature to use for the softmax attention.
437
+ (default: 1/sqrt(d_keys) where d_keys is computed at
438
+ runtime)
439
+ attention_dropout: The dropout rate to apply to the attention
440
+ (default: 0.0)
441
+ """
442
+ def __init__(self, config: GEBConfig, causal=False, softmax_scale=None, attention_dropout=0.0,
443
+ device=None, dtype=None):
444
+ super().__init__()
445
+ assert flash_attn_varlen_func is not None or flash_attn_builder is not None, \
446
+ ('Please install FlashAttention first, e.g., with pip install flash-attn or implement your own flash attention')
447
+ assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
448
+ self.config = config
449
+ self.causal = causal
450
+ self.softmax_scale = softmax_scale
451
+ self.dropout_p = attention_dropout
452
+
453
+ # Use FlashAttention-2 (flash_attn_varlen_func) when config.use_flash_attn is True
454
+ self.flash_attn_func = flash_attn_varlen_func if config.use_flash_attn else None  # FlashAttention-2 disabled when use_flash_attn is False
455
+
456
+ def forward(self, q, k, v):
457
+ """Implements the multihead softmax attention.
458
+ Arguments
459
+ ---------
460
+ q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
461
+ """
462
+ # print(i.dtype() for i in (q,k,v) )
463
+ # assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v)))
464
+ # assert all((get_accelerator().on_accelerator(i) for i in (q, k, v)))
465
+ # if get_accelerator().device_name() == 'cuda':
466
+ # assert all((i.is_cuda for i in (q,k,v)))
467
+ # else:
468
+ # assert all((i.is_xpu for i in (q,k,v)))
469
+
470
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
471
+ seqlen_k = k.shape[1]
472
+
473
+ if get_accelerator().device_name() == 'cuda':
474
+ # goes for cuda device
475
+ q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
476
+ cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
477
+ device=q.device)
478
+ else:
479
+ # goes for other device
480
+ q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in [q, k, v]]
481
+
482
+ if self.training:
483
+ # during training q,k,v always have same seqlen
484
+ assert seqlen_k == seqlen_q
485
+
486
+ is_causal = self.causal
487
+ cu_seqlens_k = cu_seqlens_q if get_accelerator().device_name() == 'cuda' else None
488
+ dropout_p = self.dropout_p
489
+ else:
490
+ # turn off FA causal mask after first inference autoregressive iteration
491
+ # only on first autoregressive step q,k,v have same seqlen
492
+ is_causal = seqlen_q == seqlen_k
493
+ cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
494
+ device=q.device) if get_accelerator().device_name() == 'cuda' else None
495
+ dropout_p = 0
496
+
497
+ output = self.flash_attn_func(
498
+ q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
499
+ dropout_p,
500
+ softmax_scale=self.softmax_scale, causal=is_causal
501
+ ) if get_accelerator().device_name() == 'cuda' else flash_attn_builder.flash_attn_func(
502
+ q, k, v, self.dropout_p, self.softmax_scale, is_causal
503
+ )
504
+
505
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) if get_accelerator().device_name() == 'cuda' else rearrange(
506
+ output, 'b h s d -> b s h d').contiguous()
507
+ return output
508
+
509
+ class GEBAttention(nn.Module):
510
+ """Parallel self-attention layer abstract class.
511
+ Self-attention layer takes input with size [s, b, h]
512
+ and returns output of the same size.
513
+ """
514
+ def __init__(self, config: GEBConfig, layer_number, device=None):
515
+ super().__init__()
516
+ self.config = config
517
+ self.layer_number = max(1, layer_number)
518
+ self.projection_size = config.kv_channels * config.num_attention_heads
519
+ self.use_flash_attn = config.use_flash_attn
520
+ # Per attention head and per partition values.
521
+ self.hidden_size_per_partition = self.projection_size
522
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
523
+ self.num_attention_heads_per_partition = config.num_attention_heads
524
+ self.num_key_value_heads_per_partition = config.num_key_value_heads
525
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
526
+ self.kv_projection_size = config.kv_channels * config.num_key_value_heads
527
+ assert self.hidden_size_per_attention_head == self.kv_projection_size // config.num_key_value_heads
528
+ # self.max_position_embeddings = config.model_max_length
529
+ if self.use_flash_attn:
530
+ global flash_attn_builder
531
+ try:
532
+ flash_attn_builder = FlashAttentionBuilder().load()
533
+ except TypeError:
534
+ flash_attn_builder = None
535
+ assert flash_attn_varlen_func is not None, "Cannot import FlashAttention v2"
536
+ if rearrange is None:
537
+ raise ImportError('einops is not installed, please install with pip install einops')
538
+
539
+ self.query = nn.Linear(config.hidden_size, self.projection_size,
540
+ bias=config.add_bias_linear,
541
+ device=device, **_config_to_kwargs(config)
542
+ )
543
+
544
+ self.key_value = nn.Linear(config.hidden_size, 2 * self.kv_projection_size,
545
+ bias=config.add_bias_linear,
546
+ device=device, **_config_to_kwargs(config)
547
+ )
548
+
549
+ if config.use_flash_attn:
550
+ self.core_attention_flash = FlashSelfAttention(config, causal=True, attention_dropout=config.attention_dropout)
551
+ else:
552
+ self.core_attention = CoreAttention(config, self.layer_number)
553
+
554
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
555
+ device=device, **_config_to_kwargs(config)
556
+ )
557
+
558
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
559
+ return torch.empty(
560
+ inference_max_sequence_len,
561
+ batch_size,
562
+ self.num_key_value_groups,
563
+ self.hidden_size_per_attention_head,
564
+ dtype=dtype,
565
+ device=device)
566
+
567
+ def repeat_kv(self, hidden_states, n_rep):
568
+ slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
569
+ if n_rep == 1:
570
+ return hidden_states
571
+ hidden_states = hidden_states[:, :, :, None, :].expand(
572
+ slen, batch, num_key_value_heads_per_partition, n_rep, head_dim)
573
+ return hidden_states.reshape(slen, batch,
574
+ num_key_value_heads_per_partition * n_rep,
575
+ head_dim)
576
+
577
+ def forward(self, hidden_states, attention_mask,
578
+ rotary_pos_emb=None, kv_cache=None, use_cache=True):
579
+ # Attention head [sq, b, h]--> [sq, b, hp]
580
+ query_layer = self.query(hidden_states)
581
+ # [sq, b, hp] --> [sq, b, np, hn]
582
+ new_tensor_shape = query_layer.size()[:-1] + \
583
+ (self.num_attention_heads_per_partition,
584
+ self.hidden_size_per_attention_head)
585
+ query_layer = query_layer.view(*new_tensor_shape)
586
+
587
+ # Attention heads [sq, b, h] --> [sq, b, (np * 2 * hn)]
588
+ mixed_kv_layer = self.key_value(hidden_states)
589
+ # [sq, b, (np * 2 * hn)] --> [sq, b, np, 2 * hn]
590
+ new_tensor_shape = mixed_kv_layer.size()[:-1] + \
591
+ (self.num_key_value_heads_per_partition,
592
+ 2 * self.hidden_size_per_attention_head)
593
+ mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
594
+ # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn]
595
+ (key_layer,
596
+ value_layer) = split_tensor_along_last_dim(
597
+ mixed_kv_layer, 2)
598
+
599
+
600
+
601
+ # Repeat kv
602
+ key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
603
+ value_layer = self.repeat_kv(value_layer,
604
+ self.num_key_value_groups)
605
+
606
+ # if rotary_pos_emb is not None:
607
+ # query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
608
+ # key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
609
+
610
+ # duplicate the pos_emb for self attention
611
+
612
+ if rotary_pos_emb is not None:
613
+ if isinstance(rotary_pos_emb, tuple):
614
+ rotary_pos_emb = rotary_pos_emb
615
+ else:
616
+ rotary_pos_emb = ((rotary_pos_emb,) * 2)
617
+ q_pos_emb, k_pos_emb = rotary_pos_emb
618
+ query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
619
+ key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
620
+
621
+
622
+ # adjust key and value for inference
623
+ if kv_cache is not None:
624
+ cache_k, cache_v = kv_cache
625
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
626
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
627
+ if use_cache:
628
+ kv_cache = (key_layer, value_layer)
629
+ else:
630
+ kv_cache = None
631
+
632
+
633
+ if self.use_flash_attn:
634
+ query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
635
+ for x in (query_layer, key_layer, value_layer)]
636
+ context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
637
+ context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
638
+ else:
639
+ context_layer = self.core_attention(
640
+ query_layer, key_layer, value_layer, attention_mask)
641
+
642
+ output = self.dense(context_layer)  # output, bias = self.dense(context_layer)
643
+
644
+ return output, kv_cache
645
+
646
+ class GEBBlock(torch.nn.Module):
647
+ """A single transformer layer.
648
+
649
+ Transformer layer takes input with size [s, b, h] and returns an
650
+ output of the same size.
651
+ """
652
+ def __init__(self, config: GEBConfig, layer_number, device=None):
653
+ super(GEBBlock, self).__init__()
654
+ self.layer_number = layer_number
655
+ self.apply_residual_connection_post_layernorm \
656
+ = config.apply_residual_connection_post_layernorm
657
+
658
+ # self.bf16 = config.bf16
659
+ self.fp32_residual_connection = config.fp32_residual_connection
660
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
661
+ dtype=config.torch_dtype)
662
+ self.self_attention = GEBAttention(config, layer_number, device=device)
663
+ self.hidden_dropout = config.hidden_dropout
664
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
665
+ dtype=config.torch_dtype)
666
+ self.mlp = MLP(config, device=device)
667
+
668
+ def forward(self, hidden_states, attention_mask=None,
669
+ rotary_pos_emb=None,
670
+ kv_cache=None,
671
+ use_cache=True):
672
+ # hidden_states: [s, b, h]
673
+ # Layer norm at the beginning of the transformer layer.
674
+ layernorm_output = self.input_layernorm(hidden_states)
675
+ # Self attention.
676
+ attention_output, kv_cache = \
677
+ self.self_attention(
678
+ layernorm_output,
679
+ attention_mask,
680
+ rotary_pos_emb=rotary_pos_emb,
681
+ kv_cache=kv_cache,
682
+ use_cache=use_cache)
683
+
684
+ # Residual connection.
685
+ if self.apply_residual_connection_post_layernorm:
686
+ residual = layernorm_output
687
+ else:
688
+ residual = hidden_states
689
+
690
+ layernorm_input = torch.nn.functional.dropout(attention_output,
691
+ p=0.0,
692
+ training=self.training)
693
+ layernorm_input = residual + layernorm_input
694
+ # Layer norm post the self attention.
695
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
696
+ # MLP.
697
+ mlp_output = self.mlp(layernorm_output)
698
+ # Second residual connection.
699
+ if self.apply_residual_connection_post_layernorm:
700
+ residual = layernorm_output
701
+ else:
702
+ residual = layernorm_input
703
+ out = torch.nn.functional.dropout(mlp_output,
704
+ p=0.0,
705
+ training=self.training)
706
+ output = residual + out
707
+ return output, kv_cache
708
+
709
+ class GEBTransformer(torch.nn.Module):
710
+ """Transformer class."""
711
+
712
+ def __init__(self, config: GEBConfig, device=None):
713
+ super(GEBTransformer, self).__init__()
714
+
715
+ self.fp32_residual_connection = config.fp32_residual_connection
716
+ self.post_layer_norm = config.post_layer_norm
717
+ self.num_layers = config.num_layers
718
+ def build_layer(layer_number):
719
+ return GEBBlock(
720
+ config,
721
+ layer_number,
722
+ device=device)
723
+ # Build the layers
724
+ self.layers = []
725
+ for i in range(self.num_layers):
726
+ layer_num = i + 1
727
+ self.layers.append(build_layer(layer_num))
728
+ self.layers = torch.nn.ModuleList(self.layers)
729
+ if self.post_layer_norm:
730
+ self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
731
+ dtype=config.torch_dtype)
732
+ self.gradient_checkpointing = False
733
+
734
+ def _get_layer(self, layer_number):
735
+ return self.layers[layer_number]
736
+
737
+ def forward(
738
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
739
+ use_cache: Optional[bool] = True,
740
+ output_hidden_states: Optional[bool] = False,
741
+ ):
742
+ if not kv_caches:
743
+ kv_caches = [None for _ in range(self.num_layers)]
744
+ presents = () if use_cache else None
745
+ if self.gradient_checkpointing and self.training:
746
+ if use_cache:
747
+ logger.warning_once(
748
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
749
+ )
750
+ use_cache = False
751
+
752
+ all_self_attentions = None
753
+ all_hidden_states = () if output_hidden_states else None
754
+ for index in range(self.num_layers):
755
+ if output_hidden_states:
756
+ all_hidden_states = all_hidden_states + (hidden_states,)
757
+ layer = self._get_layer(index)
758
+ if self.gradient_checkpointing and self.training:
759
+ layer_hidden = torch.utils.checkpoint.checkpoint(
760
+ layer,
761
+ hidden_states,
762
+ attention_mask,
763
+ rotary_pos_emb,
764
+ kv_caches[index],
765
+ use_cache
766
+ )
767
+ else:
768
+ layer_hidden = layer(
769
+ hidden_states,
770
+ attention_mask,
771
+ rotary_pos_emb,
772
+ kv_cache=kv_caches[index],
773
+ use_cache=use_cache
774
+ )
775
+ hidden_states, kv_cache = layer_hidden
776
+ if use_cache:
777
+ presents = presents + (kv_cache,)
778
+
779
+ if output_hidden_states:
780
+ all_hidden_states = all_hidden_states + (hidden_states,)
781
+ if self.post_layer_norm:
782
+ hidden_states = self.final_layernorm(hidden_states)
783
+ return hidden_states, presents, all_hidden_states, all_self_attentions
784
+
785
+
786
+ class GEBPreTrainedModel(PreTrainedModel):
787
+ """
788
+ An abstract class to handle weights initialization and
789
+ a simple interface for downloading and loading pretrained models.
790
+ """
791
+
792
+ is_parallelizable = False
793
+ supports_gradient_checkpointing = True
794
+ config_class = GEBConfig
795
+ base_model_prefix = "transformer"
796
+ _no_split_modules = ["GEBBlock"]
797
+
798
+ def _init_weights(self, module: nn.Module):
799
+ """Initialize the weights."""
800
+ return
801
+
802
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
803
+ batch_size, seq_length = input_ids.shape
804
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
805
+ full_attention_mask.tril_()
806
+ past_length = 0
807
+ if past_key_values:
808
+ past_length = past_key_values[0][0].shape[0]
809
+ if past_length:
810
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
811
+ device=input_ids.device), full_attention_mask), dim=-1)
812
+ if padding_mask is not None:
813
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
814
+ if not past_length and padding_mask is not None:
815
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
816
+ full_attention_mask = (full_attention_mask < 0.5).bool()
817
+ full_attention_mask.unsqueeze_(1)
818
+ return full_attention_mask
819
+
820
+ def get_position_ids(self, input_ids, device):
821
+ batch_size, seq_length = input_ids.shape
822
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
823
+ return position_ids
824
+
825
+ def _set_gradient_checkpointing(self, module, value=False):
826
+ if isinstance(module, GEBTransformer):
827
+ module.gradient_checkpointing = value
828
+
829
+ class Embedding(torch.nn.Module):
830
+ """Language model embeddings."""
831
+
832
+ def __init__(self, config: GEBConfig, device=None):
833
+ super(Embedding, self).__init__()
834
+
835
+ self.hidden_size = config.hidden_size
836
+ # Word embeddings.
837
+ self.word_embeddings = nn.Embedding(
838
+ config.padded_vocab_size,
839
+ self.hidden_size,
840
+ dtype=config.torch_dtype,
841
+ device=device
842
+ )
843
+ self.fp32_residual_connection = config.fp32_residual_connection
844
+
845
+ def forward(self, input_ids):
846
+ # Embeddings.
847
+ words_embeddings = self.word_embeddings(input_ids)
848
+ embeddings = words_embeddings
849
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
850
+ embeddings = embeddings.transpose(0, 1).contiguous()
851
+ # If the input flag for fp32 residual connection is set, convert for float.
852
+ if self.fp32_residual_connection:
853
+ embeddings = embeddings.float()
854
+ return embeddings
855
+
856
+ class GEBModel(GEBPreTrainedModel):
857
+ def __init__(self, config: GEBConfig, device=None, empty_init=True):
858
+ super().__init__(config)
859
+ if empty_init:
860
+ init_method = skip_init
861
+ else:
862
+ init_method = default_init
863
+ init_kwargs = {}
864
+ if device is not None:
865
+ init_kwargs["device"] = device
866
+ self.embedding = init_method(Embedding, config, **init_kwargs)
867
+ self.num_layers = config.num_layers
868
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
869
+ self.kv_channels = config.kv_channels
870
+
871
+ # Rotary positional embeddings
872
+ self.seq_length = config.seq_length
873
+ rotary_dim = (
874
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
875
+ )
876
+
877
+ # self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl= True, device=device,
878
+ # dtype=config.torch_dtype)
879
+
880
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
881
+
882
+
883
+ self.encoder = init_method(GEBTransformer, config, **init_kwargs)
884
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
885
+ dtype=config.torch_dtype, **init_kwargs)
886
+ self.pre_seq_len = config.pre_seq_len
887
+ self.prefix_projection = config.prefix_projection
888
+ if self.pre_seq_len is not None:
889
+ for param in self.parameters():
890
+ param.requires_grad = False
891
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
892
+ self.prefix_encoder = PrefixEncoder(config)
893
+ self.dropout = torch.nn.Dropout(0.1)
894
+
895
+ def get_input_embeddings(self):
896
+ return self.embedding.word_embeddings
897
+
898
+ def get_prompt(self, batch_size, device, dtype=torch.half):
899
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
900
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
901
+ past_key_values = past_key_values.view(
902
+ batch_size,
903
+ self.pre_seq_len,
904
+ self.num_layers * 2,
905
+ self.num_key_value_groups,
906
+ self.kv_channels
907
+ )
908
+ # seq_len, b, nh, hidden_size
909
+ past_key_values = self.dropout(past_key_values)
910
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
911
+ return past_key_values
912
+
913
+ def forward(
914
+ self,
915
+ input_ids,
916
+ position_ids: Optional[torch.Tensor] = None,
917
+ attention_mask: Optional[torch.BoolTensor] = None,
918
+ full_attention_mask: Optional[torch.BoolTensor] = None,
919
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
920
+ inputs_embeds: Optional[torch.Tensor] = None,
921
+ use_cache: Optional[bool] = None,
922
+ output_hidden_states: Optional[bool] = None,
923
+ return_dict: Optional[bool] = None,
924
+ ):
925
+ output_hidden_states = (
926
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
927
+ )
928
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
929
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
930
+
931
+ batch_size, seq_length = input_ids.shape
932
+ if inputs_embeds is None:
933
+ inputs_embeds = self.embedding(input_ids)
934
+
935
+ if self.pre_seq_len is not None:
936
+ if past_key_values is None:
937
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
938
+ dtype=inputs_embeds.dtype)
939
+ if attention_mask is not None:
940
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
941
+ attention_mask], dim=-1)
942
+
943
+ if full_attention_mask is None:
944
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
945
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
946
+
947
+ # # Rotary positional embeddings
948
+ # rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
949
+ # if position_ids is not None:
950
+ # rotary_pos_emb = rotary_pos_emb[position_ids]
951
+ # else:
952
+ # rotary_pos_emb = rotary_pos_emb[None, :seq_length]
953
+ # rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
954
+ # Rotary positional embeddings
955
+ # print(position_ids[0].tolist())
956
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
957
+ rotary_pos_emb = rotary_pos_emb[position_ids[0].tolist()]
958
+ # rotary_pos_emb = self.rotary_pos_emb(position_ids.shape[-1])
959
+
960
+
961
+ # # Rotary positional embeddings emb [seq_length, .., dim] no not need transpose
962
+ # rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
963
+ # rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
964
+ # print('rotary_pos_emb:', rotary_pos_emb.size())
965
+ # if position_ids is not None:
966
+ # rotary_pos_emb = rotary_pos_emb[position_ids]
967
+ # print('rotary_pos_emb:', rotary_pos_emb.size())
968
+ # else:
969
+ # rotary_pos_emb = rotary_pos_emb[None, :seq_length]
970
+ # # rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
971
+
972
+ # Run encoder.
973
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
974
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
975
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
976
+ )
977
+
978
+ if not return_dict:
979
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
980
+
981
+ return BaseModelOutputWithPast(
982
+ last_hidden_state=hidden_states,
983
+ past_key_values=presents,
984
+ hidden_states=all_hidden_states,
985
+ attentions=all_self_attentions,
986
+ )
987
+
988
+ class GEBForCausalLM(GEBPreTrainedModel):
989
+ def __init__(self, config: GEBConfig, empty_init=True, device=None):
990
+ super().__init__(config)
991
+
992
+ self.max_sequence_length = config.max_length
993
+ self.transformer = GEBModel(config, empty_init=empty_init, device=device)
994
+ self.config = config
995
+ self.quantized = False
996
+
997
+ # if self.config.quantization_bit:
998
+ # self.quantize(self.config.quantization_bit, empty_init=True)
999
+
1000
+ def _update_model_kwargs_for_generation(
1001
+ self,
1002
+ outputs: ModelOutput,
1003
+ model_kwargs: Dict[str, Any],
1004
+ is_encoder_decoder: bool = False,
1005
+ standardize_cache_format: bool = False,
1006
+ ) -> Dict[str, Any]:
1007
+ # update past_key_values
1008
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
1009
+ outputs, standardize_cache_format=standardize_cache_format
1010
+ )
1011
+
1012
+ # update attention mask
1013
+ if "attention_mask" in model_kwargs:
1014
+ attention_mask = model_kwargs["attention_mask"]
1015
+ model_kwargs["attention_mask"] = torch.cat(
1016
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
1017
+ )
1018
+
1019
+ # update position ids
1020
+ if "position_ids" in model_kwargs:
1021
+ position_ids = model_kwargs["position_ids"]
1022
+ new_position_id = position_ids[..., -1:].clone()
1023
+ new_position_id += 1
1024
+ model_kwargs["position_ids"] = torch.cat(
1025
+ [position_ids, new_position_id], dim=-1
1026
+ )
1027
+
1028
+ model_kwargs["is_first_forward"] = False
1029
+ return model_kwargs
1030
+
1031
+ def prepare_inputs_for_generation(
1032
+ self,
1033
+ input_ids: torch.LongTensor,
1034
+ past_key_values: Optional[torch.Tensor] = None,
1035
+ attention_mask: Optional[torch.Tensor] = None,
1036
+ position_ids: Optional[torch.Tensor] = None,
1037
+ use_cache: Optional[bool] = None,
1038
+ is_first_forward: bool = True,
1039
+ **kwargs
1040
+ ) -> dict:
1041
+ # only last token for input_ids if past is not None
1042
+ if position_ids is None:
1043
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
1044
+ if not is_first_forward:
1045
+ if past_key_values is not None:
1046
+ position_ids = position_ids[..., -1:]
1047
+ input_ids = input_ids[:, -1:]
1048
+ return {
1049
+ "input_ids": input_ids,
1050
+ "past_key_values": past_key_values,
1051
+ "position_ids": position_ids,
1052
+ "attention_mask": attention_mask,
1053
+ "return_last_logit": True,
1054
+ "use_cache": use_cache
1055
+ }
1056
+
1057
+ def forward(
1058
+ self,
1059
+ input_ids: Optional[torch.Tensor] = None,
1060
+ position_ids: Optional[torch.Tensor] = None,
1061
+ attention_mask: Optional[torch.Tensor] = None,
1062
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1063
+ inputs_embeds: Optional[torch.Tensor] = None,
1064
+ labels: Optional[torch.Tensor] = None,
1065
+ use_cache: Optional[bool] = None,
1066
+ output_attentions: Optional[bool] = None,
1067
+ output_hidden_states: Optional[bool] = None,
1068
+ return_dict: Optional[bool] = None,
1069
+ return_last_logit: Optional[bool] = False,
1070
+ ):
1071
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1072
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1073
+
1074
+ transformer_outputs = self.transformer(
1075
+ input_ids=input_ids,
1076
+ position_ids=position_ids,
1077
+ attention_mask=attention_mask,
1078
+ past_key_values=past_key_values,
1079
+ inputs_embeds=inputs_embeds,
1080
+ use_cache=use_cache,
1081
+ output_hidden_states=output_hidden_states,
1082
+ return_dict=return_dict,
1083
+ )
1084
+
1085
+ hidden_states = transformer_outputs[0]
1086
+ if return_last_logit:
1087
+ hidden_states = hidden_states[-1:]
1088
+ lm_logits = self.transformer.output_layer(hidden_states)
1089
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
1090
+
1091
+ loss = None
1092
+ if labels is not None:
1093
+ lm_logits = lm_logits.to(torch.float32)
1094
+
1095
+ # Shift so that tokens < n predict n
1096
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1097
+ shift_labels = labels[..., 1:].contiguous()
1098
+ # Flatten the tokens
1099
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1100
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1101
+
1102
+ lm_logits = lm_logits.to(hidden_states.dtype)
1103
+ loss = loss.to(hidden_states.dtype)
1104
+
1105
+ if not return_dict:
1106
+ output = (lm_logits,) + transformer_outputs[1:]
1107
+ return ((loss,) + output) if loss is not None else output
1108
+
1109
+ return CausalLMOutputWithPast(
1110
+ loss=loss,
1111
+ logits=lm_logits,
1112
+ past_key_values=transformer_outputs.past_key_values,
1113
+ hidden_states=transformer_outputs.hidden_states,
1114
+ attentions=transformer_outputs.attentions,
1115
+ )
1116
+
1117
+ @staticmethod
1118
+ def _reorder_cache(
1119
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1120
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1121
+ """
1122
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1123
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1124
+ beam_idx at every generation step.
1125
+
1126
+ Output shares the same memory storage as `past`.
1127
+ """
1128
+ return tuple(
1129
+ (
1130
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
1131
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1132
+ )
1133
+ for layer_past in past
1134
+ )
1135
+
1136
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
1137
+ prompt = tokenizer.build_prompt(query, history=history)
1138
+ tokens = [tokenizer.get_command("<bos>")] + tokenizer.encode(prompt)
1139
+ inputs = tokenizer.batch_encode_plus([tokens], return_tensors="pt", is_split_into_words=True)
1140
+ inputs = inputs.to(self.device)
1141
+ return inputs
1142
+
1143
+ # def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None):
1144
+ # prompt = tokenizer.build_prompt(query, history=history)
1145
+ # inputs = tokenizer([prompt], return_tensors="pt")
1146
+ # # print(inputs)
1147
+ # inputs = inputs.to(self.device)
1148
+ # return inputs
1149
+
1150
+ @torch.inference_mode()
1151
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 512, num_beams=1,
1152
+ do_sample=True, top_p=0.5, temperature=0.3, logits_processor=None, repetition_penalty = 1.15, **kwargs):
1153
+ if history is None:
1154
+ history = []
1155
+ if logits_processor is None:
1156
+ logits_processor = LogitsProcessorList()
1157
+ logits_processor.append(InvalidScoreLogitsProcessor())
1158
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1159
+ "temperature": temperature, "logits_processor": logits_processor, "repetition_penalty":repetition_penalty, **kwargs}
1160
+ prompt = tokenizer.build_prompt(query, history=[])
1161
+ system = "You are a helpful assistant.\n"
1162
+ system_ids = [
1163
+ tokenizer.get_command("<bos>")
1164
+ ] + tokenizer.encode(text=system) + [
1165
+ tokenizer.get_command("<eos>")]
1166
+
1167
+ prompt_ids = [
1168
+ tokenizer.get_command("<bos>")
1169
+ ] + tokenizer.encode(
1170
+ text=prompt,
1171
+ add_special_tokens=False
1172
+ ) + [
1173
+ tokenizer.get_command("<eos>")] + [
1174
+ tokenizer.get_command("<bos>")]
1175
+ tokens = system_ids + prompt_ids
1176
+ inputs = tokenizer.batch_encode_plus([tokens], return_tensors="pt", is_split_into_words=True)
1177
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1178
+ inputs = inputs.to(device)
1179
+ outputs = self.generate(**inputs, **gen_kwargs)
1180
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
1181
+ response = tokenizer.decode(outputs)
1182
+ return response, history
1183
+
1184
+
1185
+
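A usage sketch for the `chat()` helper defined above. The repo path is a placeholder, `AutoModel` resolves to `GEBForCausalLM` through the `auto_map` in config.json, and the dtype/device choices are assumptions rather than documented requirements.

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo = "path/to/this/repo"  # placeholder: local clone or Hub repo id
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True,
                                  torch_dtype=torch.bfloat16)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

# chat() builds the prompt with tokenizer.build_prompt() and calls generate()
response, history = model.chat(tokenizer, "Hello, who are you?", history=[])
print(response)
```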
tokenization_geb.py ADDED
@@ -0,0 +1,280 @@
+ import json
+ import os
+ import re
+ from typing import List, Optional, Union, Dict
+ from sentencepiece import SentencePieceProcessor
+ from transformers import PreTrainedTokenizer
+ from transformers.utils import logging, PaddingStrategy
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
+ class SPTokenizer:
+ def __init__(self, model_path: str):
+ # reload tokenizer
+ assert os.path.isfile(model_path), model_path
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+ # BOS / EOS token IDs
+ self.n_words: int = self.sp_model.vocab_size()
+ self.bos_id: int = self.sp_model.bos_id()
+ self.eos_id: int = self.sp_model.eos_id()
+ self.pad_id: int = self.sp_model.unk_id()
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+ special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<sop>", "<eop>", "<eod>", "", "", "",
+ "", "<bos>", "<eos>"]
+ self.special_tokens = {}
+ self.index_special_tokens = {}
+ for token in special_tokens:
+ if token == "<bos>":
+ self.special_tokens["<bos>"] = self.bos_id
+ self.index_special_tokens[self.bos_id] = "<bos>"
+ elif token == "<eos>":
+ self.special_tokens["<eos>"] = self.eos_id
+ self.index_special_tokens[self.eos_id] = "<eos>"
+ else:
+ self.special_tokens[token] = self.n_words
+ self.index_special_tokens[self.n_words] = token
+ self.n_words += 1
+
+
+ def tokenize(self, s: str):
+ return self.sp_model.EncodeAsPieces(s)
+
+ def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
+ assert type(s) is str
+ t = self.sp_model.encode(s)
+ if bos:
+ t = [self.bos_id] + t
+ if eos:
+ t = t + [self.eos_id]
+ return t
+
+ def decode(self, t: List[int]) -> str:
+ return self.sp_model.decode(t)
+
+ def decode_tokens(self, tokens: List[str]) -> str:
+ text = self.sp_model.DecodePieces(tokens)
+ return text
+
+ def convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ return self.sp_model.PieceToId(token)
+
+ def convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+ return ""
+ return self.sp_model.IdToPiece(index)
+
+ class GEBTokenizer(PreTrainedTokenizer):
+ """SentencePieceTokenizer-Megatron wrapper"""
+ vocab_files_names = {"vocab_file": "GEBtokenizer.model"}
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
+ def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
+
+
+ self.name = 'GEBTokenizer'
+ self.vocab_file = vocab_file
+ self.tokenizer = SPTokenizer(vocab_file)
+ self.special_tokens = {
+ "<bos>": self.tokenizer.bos_id,
+ "<eos>": self.tokenizer.eos_id,
+ "<pad>": self.tokenizer.pad_id
+ }
+ super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
+
+ def get_command(self, token):
+ if token in self.special_tokens:
+ return self.special_tokens[token]
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
+ return self.tokenizer.special_tokens[token]
+
+ # def tokenize(self, text):
+ # return self.tokenizer.encode(text, bos=True, eos=False
+ # )
+
+ # def detokenize(self, ids):
+ # return self.tokenizer.decode(ids)
+
+ # def _convert_token_to_id(self, token):
+ # """ Converts a token (str) in an id using the vocab. """
+ # return self.tokenizer.convert_token_to_id(token)
+
+ # def _convert_id_to_token(self, index):
+ # """Converts an index (integer) in a token (str) using the vocab."""
+ # return self.tokenizer.convert_id_to_token(index)
+
+ @property
+ def eos_token(self) -> str:
+ return "<eos>"
+
+ @property
+ def bos_token(self) -> str:
+ return "<bos>"
+
+ @property
+ def eod_token(self) -> str:
+ return "<eod>"
+
+ @property
+ def pad_token_id(self):
+ return self.get_command("<pad>")
+
+ @property
+ def bos_token_id(self):
+ return self.get_command("<bos>")
+
+ @property
+ def eos_token_id(self):
+ return self.get_command("<eos>")
+
+ @property
+ def eod_token_id(self):
+ return self.get_command("<eod>")
+
+ @property
+ def vocab_size(self):
+ return self.tokenizer.n_words
+
+ def get_vocab(self):
+ """ Returns vocab as a dict """
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text, **kwargs):
+ return self.tokenizer.tokenize(text)
+
+ def _convert_token_to_id(self, token):
+ """ Converts a token (str) in an id using the vocab. """
+ return self.tokenizer.convert_token_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ return self.tokenizer.convert_id_to_token(index)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ return self.tokenizer.decode_tokens(tokens)
+
+ def save_vocabulary(self, save_directory, filename_prefix=None):
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+ filename_prefix (`str`, *optional*):
+ An optional prefix to add to the names of the saved files.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, self.vocab_files_names["vocab_file"]
+ )
+ else:
+ vocab_file = save_directory
+
+ with open(self.vocab_file, 'rb') as fin:
+ proto_str = fin.read()
+
+ with open(vocab_file, "wb") as writer:
+ writer.write(proto_str)
+
+ return (vocab_file,)
+
+ def get_prefix_tokens(self):
+ prefix_tokens = [self.get_command("[gMASK]"), self.get_command("<sop>")]  # registered as "<sop>" in SPTokenizer.special_tokens
+ return prefix_tokens
+
+ def build_finetune_prompt(self, query, history=None):
+ if history is None:
+ history = []
+ prompt = ""
+ prompt += "问题:{}\n\n回答:".format(query)  # fine-tuning template: "Question: {}\n\nAnswer:"
+ return prompt
+
+ def build_single_message(self, message):
+ role_tokens = [self.get_command("<eos>")]
+ message_tokens = self.tokenizer.encode(message)
+ tokens = role_tokens + message_tokens
+ return tokens
+
+ def build_chat_input(self, query, history=None):
+ if history is None:
+ history = []
+ input_ids = []
+ input_ids.extend(self.build_single_message(query))
+ return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
+
+ def build_prompt(self, query, history=None):
+ if history is None:
+ history = []
+ prompt = query
+ return prompt
+
+ def _pad(
+ self,
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
+ max_length: Optional[int] = None,
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+ pad_to_multiple_of: Optional[int] = None,
+ return_attention_mask: Optional[bool] = None,
+ ) -> dict:
+ """
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+ Args:
+ encoded_inputs:
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+ max_length: maximum length of the returned list and optionally padding length (see below).
+ Will truncate by taking into account the special tokens.
+ padding_strategy: PaddingStrategy to use for padding.
+
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
+ The tokenizer padding sides are defined in self.padding_side:
+
+ - 'left': pads on the left of the sequences
+ - 'right': pads on the right of the sequences
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+ `>= 7.5` (Volta).
+ return_attention_mask:
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
+ """
+ # Load from model defaults
+ assert self.padding_side == "left"
+
+ required_input = encoded_inputs[self.model_input_names[0]]
+ seq_length = len(required_input)
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = len(required_input)
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+ # Initialize attention mask if not present.
+ if "attention_mask" not in encoded_inputs:
+ encoded_inputs["attention_mask"] = [1] * seq_length
+
+ if "position_ids" not in encoded_inputs:
+ encoded_inputs["position_ids"] = list(range(seq_length))
+
+ if needs_to_be_padded:
+ difference = max_length - len(required_input)
+
+ if "attention_mask" in encoded_inputs:
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+ if "position_ids" in encoded_inputs:
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
+ encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+
+ return encoded_inputs
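A short sketch of what the tokenizer above produces (not part of the upload). It assumes a local copy of tokenization_geb.py plus the SentencePiece file named in vocab_files_names ("GEBtokenizer.model"); the printed shapes and values are illustrative only.

from tokenization_geb import GEBTokenizer

tokenizer = GEBTokenizer("GEBtokenizer.model")  # path assumed; the model file ships with the repo

# ids without special tokens, mirroring the encode(...) call used in modeling_geb.py
ids = tokenizer.encode(text="hello world", add_special_tokens=False)
tokens = [tokenizer.get_command("<bos>")] + ids + [tokenizer.get_command("<eos>")]

# padding_side="left", so _pad() prepends pad ids and zero attention-mask/position-id entries
batch = tokenizer.batch_encode_plus(
    [tokens, tokens + tokens],
    is_split_into_words=True,
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)      # both rows padded to the longer sequence
print(batch["attention_mask"][0])    # leading zeros mark the left padding
print(batch["position_ids"][0])      # position ids are left-padded with zeros as well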
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+ "name_or_path": "",
+ "remove_space": false,
+ "do_lower_case": false,
+ "tokenizer_class": "GEBTokenizer",
+ "auto_map": {
+ "AutoTokenizer": [
+ "tokenization_geb.GEBTokenizer",
+ null
+ ]
+ }
+ }
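As a quick check of the auto_map above (again a sketch with a placeholder path), AutoTokenizer should resolve to tokenization_geb.GEBTokenizer once trust_remote_code is enabled:

from transformers import AutoTokenizer

# "." stands for a local checkout containing these files; a Hub repo id works the same way
tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

print(type(tokenizer).__name__)   # GEBTokenizer, resolved through the auto_map entry
print(tokenizer.eos_token, tokenizer.eos_token_id)
print(tokenizer.padding_side)     # "left", the default set in GEBTokenizer.__init__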