Zihui Ren commited on
Commit
578225f
1 Parent(s): dc9fef8

upload models

Browse files
config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "model_repos/Qwen-7B-QAnything/",
3
+ "activation": "swiglu",
4
+ "apply_residual_connection_post_layernorm": false,
5
+ "architectures": [
6
+ "QWenLMHeadModel"
7
+ ],
8
+ "attn_pdrop": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_qwen.QWenConfig",
11
+ "AutoModel": "modeling_qwen.QWenLMHeadModel",
12
+ "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
13
+ },
14
+ "bf16": false,
15
+ "bias_dropout_fusion": true,
16
+ "bos_token_id": 151643,
17
+ "embd_pdrop": 0.1,
18
+ "eos_token_id": 151643,
19
+ "ffn_hidden_size": 22016,
20
+ "fp16": false,
21
+ "initializer_range": 0.02,
22
+ "kv_channels": 128,
23
+ "layer_norm_epsilon": 1e-06,
24
+ "model_type": "qwen",
25
+ "n_embd": 4096,
26
+ "n_head": 32,
27
+ "n_inner": null,
28
+ "n_layer": 32,
29
+ "n_positions": 8192,
30
+ "no_bias": true,
31
+ "onnx_safe": null,
32
+ "padded_vocab_size": 151936,
33
+ "params_dtype": "torch.bfloat16",
34
+ "pos_emb": "rotary",
35
+ "resid_pdrop": 0.1,
36
+ "rotary_emb_base": 10000,
37
+ "rotary_pct": 1.0,
38
+ "scale_attn_weights": true,
39
+ "seq_length": 8192,
40
+ "tie_word_embeddings": false,
41
+ "tokenizer_type": "QWenTokenizer",
42
+ "torch_dtype": "bfloat16",
43
+ "transformers_version": "4.31.0",
44
+ "use_cache": true,
45
+ "use_dynamic_ntk": true,
46
+ "use_flash_attn": true,
47
+ "use_logn_attn": true,
48
+ "vocab_size": 151936
49
+ }
configuration_qwen.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
+ class QWenConfig(PretrainedConfig):
10
+ model_type = "qwen"
11
+ keys_to_ignore_at_inference = ["past_key_values"]
12
+ attribute_map = {
13
+ "hidden_size": "n_embd",
14
+ "num_attention_heads": "n_head",
15
+ "max_position_embeddings": "n_positions",
16
+ "num_hidden_layers": "n_layer",
17
+ }
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size=151851,
22
+ n_embd=4096,
23
+ n_layer=32,
24
+ n_head=32,
25
+ n_inner=None,
26
+ embd_pdrop=0.0,
27
+ attn_pdrop=0.0,
28
+ layer_norm_epsilon=1e-5,
29
+ initializer_range=0.02,
30
+ scale_attn_weights=True,
31
+ use_cache=True,
32
+ eos_token_id=151643,
33
+ apply_residual_connection_post_layernorm=False,
34
+ bf16=True,
35
+ kv_channels=128,
36
+ rotary_pct=1.0,
37
+ rotary_emb_base=10000,
38
+ use_dynamic_ntk=False,
39
+ use_logn_attn=False,
40
+ use_flash_attn=True,
41
+ ffn_hidden_size=22016,
42
+ no_bias=True,
43
+ tie_word_embeddings=False,
44
+ **kwargs,
45
+ ):
46
+ self.eos_token_id = eos_token_id
47
+ super().__init__(
48
+ eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
49
+ )
50
+
51
+ self.vocab_size = vocab_size
52
+ self.n_embd = n_embd
53
+ self.n_layer = n_layer
54
+ self.n_head = n_head
55
+ self.n_inner = n_inner
56
+ self.embd_pdrop = embd_pdrop
57
+ self.attn_pdrop = attn_pdrop
58
+ self.layer_norm_epsilon = layer_norm_epsilon
59
+ self.initializer_range = initializer_range
60
+ self.scale_attn_weights = scale_attn_weights
61
+ self.use_cache = use_cache
62
+ self.apply_residual_connection_post_layernorm = (
63
+ apply_residual_connection_post_layernorm
64
+ )
65
+ self.bf16 = bf16
66
+ self.kv_channels = kv_channels
67
+ self.rotary_pct = rotary_pct
68
+ self.rotary_emb_base = rotary_emb_base
69
+ self.use_dynamic_ntk = use_dynamic_ntk
70
+ self.use_logn_attn = use_logn_attn
71
+ self.use_flash_attn = use_flash_attn
72
+ self.ffn_hidden_size = ffn_hidden_size
73
+ self.no_bias = no_bias
74
+ self.tie_word_embeddings = tie_word_embeddings
generation_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "do_sample": true,
4
+ "temperature": 0.6,
5
+ "top_p": 0.8,
6
+ "top_k": 0,
7
+ "repetition_penalty": 1.05,
8
+ "max_new_tokens": 512,
9
+ "bos_token_id": 151643,
10
+ "eos_token_id": 151643,
11
+ "transformers_version": "4.31.0"
12
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_qwen.py ADDED
@@ -0,0 +1,1082 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import logging as log_print
8
+ import math
9
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from torch.cuda.amp import autocast
15
+
16
+ from torch.nn import CrossEntropyLoss
17
+ from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
18
+ from transformers.generation.logits_process import LogitsProcessorList
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.generation.streamers import BaseStreamer
22
+ from transformers.generation.utils import GenerateOutput
23
+ from transformers.modeling_outputs import (
24
+ BaseModelOutputWithPast,
25
+ CausalLMOutputWithPast,
26
+ )
27
+ from transformers.modeling_utils import PreTrainedModel
28
+ from transformers.utils import logging
29
+
30
+ try:
31
+ from einops import rearrange
32
+ except ImportError:
33
+ rearrange = None
34
+ from torch import nn
35
+
36
+ try:
37
+ from flash_attn.layers.rotary import apply_rotary_emb_func
38
+ from einops import rearrange
39
+
40
+ use_flash_rotary = True
41
+ except ImportError:
42
+ use_flash_rotary = False
43
+ print(
44
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
45
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
46
+ )
47
+
48
+ try:
49
+ from flash_attn.ops.rms_norm import rms_norm
50
+ except ImportError:
51
+ rms_norm = None
52
+ print(
53
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
54
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
55
+ )
56
+
57
+ from .configuration_qwen import QWenConfig
58
+ from .qwen_generation_utils import (
59
+ HistoryType,
60
+ make_context,
61
+ decode_tokens,
62
+ get_stop_words_ids,
63
+ StopWordsLogitsProcessor,
64
+ )
65
+
66
+
67
+ log_print.basicConfig(level=log_print.DEBUG)
68
+
69
+
70
+ logger = logging.get_logger(__name__)
71
+
72
+ _CHECKPOINT_FOR_DOC = "qwen"
73
+ _CONFIG_FOR_DOC = "QWenConfig"
74
+
75
+ QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
76
+
77
+ try:
78
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func
79
+ except ImportError:
80
+ flash_attn_unpadded_func = None
81
+ print(
82
+ "Warning: import flash_attn fail, please install FlashAttention "
83
+ "https://github.com/Dao-AILab/flash-attention"
84
+ )
85
+
86
+
87
+ class FlashSelfAttention(torch.nn.Module):
88
+ def __init__(
89
+ self,
90
+ causal=False,
91
+ softmax_scale=None,
92
+ attention_dropout=0.0,
93
+ ):
94
+ super().__init__()
95
+ assert flash_attn_unpadded_func is not None, (
96
+ "Please install FlashAttention first, " "e.g., with pip install flash-attn"
97
+ )
98
+ assert (
99
+ rearrange is not None
100
+ ), "Please install einops first, e.g., with pip install einops"
101
+ self.causal = causal
102
+ self.softmax_scale = softmax_scale
103
+ self.dropout_p = attention_dropout
104
+
105
+ def forward(self, q, k, v):
106
+ assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
107
+ assert all((i.is_cuda for i in (q, k, v)))
108
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
109
+ seqlen_k = k.shape[1]
110
+ q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
111
+ cu_seqlens_q = torch.arange(
112
+ 0,
113
+ (batch_size + 1) * seqlen_q,
114
+ step=seqlen_q,
115
+ dtype=torch.int32,
116
+ device=q.device,
117
+ )
118
+
119
+ if self.training:
120
+ assert seqlen_k == seqlen_q
121
+
122
+ is_causal = self.causal
123
+ cu_seqlens_k = cu_seqlens_q
124
+ else:
125
+ is_causal = seqlen_q == seqlen_k
126
+ cu_seqlens_k = torch.arange(
127
+ 0,
128
+ (batch_size + 1) * seqlen_k,
129
+ step=seqlen_k,
130
+ dtype=torch.int32,
131
+ device=q.device,
132
+ )
133
+ self.dropout_p = 0
134
+ output = flash_attn_unpadded_func(
135
+ q,
136
+ k,
137
+ v,
138
+ cu_seqlens_q,
139
+ cu_seqlens_k,
140
+ seqlen_q,
141
+ seqlen_k,
142
+ self.dropout_p,
143
+ softmax_scale=self.softmax_scale,
144
+ causal=is_causal,
145
+ )
146
+
147
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
148
+ return output
149
+
150
+
151
+ class QWenAttention(nn.Module):
152
+ def __init__(self, config, layer_number=None):
153
+ super().__init__()
154
+
155
+ max_positions = config.max_position_embeddings
156
+ self.register_buffer(
157
+ "bias",
158
+ torch.tril(
159
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
160
+ ).view(1, 1, max_positions, max_positions),
161
+ persistent=False,
162
+ )
163
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
164
+ self.layer_number = max(1, layer_number)
165
+ self.params_dtype = config.params_dtype
166
+ self.seq_length = config.seq_length
167
+
168
+ self.hidden_size = config.hidden_size
169
+ self.split_size = config.hidden_size
170
+ self.num_heads = config.num_attention_heads
171
+ self.head_dim = self.hidden_size // self.num_heads
172
+
173
+ self.use_flash_attn = config.use_flash_attn
174
+ self.scale_attn_weights = True
175
+
176
+ self.layer_idx = None
177
+
178
+ self.projection_size = config.kv_channels * config.num_attention_heads
179
+
180
+ assert self.projection_size % config.num_attention_heads == 0
181
+ self.hidden_size_per_attention_head = (
182
+ self.projection_size // config.num_attention_heads
183
+ )
184
+
185
+ self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
186
+
187
+ self.c_proj = nn.Linear(
188
+ config.hidden_size, self.projection_size, bias=not config.no_bias
189
+ )
190
+
191
+ self.is_fp32 = not (config.bf16 or config.fp16)
192
+ if (
193
+ self.use_flash_attn
194
+ and flash_attn_unpadded_func is not None
195
+ and not self.is_fp32
196
+ ):
197
+ self.core_attention_flash = FlashSelfAttention(
198
+ causal=True, attention_dropout=config.attn_pdrop
199
+ )
200
+
201
+ self.bf16 = config.bf16
202
+
203
+ if config.rotary_pct == 1.0:
204
+ self.rotary_ndims = None
205
+ else:
206
+ assert config.rotary_pct < 1
207
+ self.rotary_ndims = int(
208
+ self.hidden_size_per_attention_head * config.rotary_pct
209
+ )
210
+ dim = (
211
+ self.rotary_ndims
212
+ if self.rotary_ndims is not None
213
+ else self.hidden_size_per_attention_head
214
+ )
215
+ self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
216
+
217
+ self.use_dynamic_ntk = config.use_dynamic_ntk
218
+ self.use_logn_attn = config.use_logn_attn
219
+
220
+ logn_list = [
221
+ math.log(i, self.seq_length) if i > self.seq_length else 1
222
+ for i in range(1, 32768)
223
+ ]
224
+ self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
225
+ self._ntk_cached = 1.0
226
+
227
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
228
+
229
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
230
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
231
+
232
+ if self.scale_attn_weights:
233
+ attn_weights = attn_weights / torch.full(
234
+ [],
235
+ value.size(-1) ** 0.5,
236
+ dtype=attn_weights.dtype,
237
+ device=attn_weights.device,
238
+ )
239
+
240
+ query_length, key_length = query.size(-2), key.size(-2)
241
+ causal_mask = self.bias[
242
+ :, :, key_length - query_length : key_length, :key_length
243
+ ]
244
+ mask_value = torch.finfo(attn_weights.dtype).min
245
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
246
+ attn_weights.device
247
+ )
248
+ attn_weights = torch.where(
249
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
250
+ )
251
+
252
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
253
+
254
+ attn_weights = attn_weights.type(value.dtype)
255
+ attn_weights = self.attn_dropout(attn_weights)
256
+
257
+ if head_mask is not None:
258
+ attn_weights = attn_weights * head_mask
259
+
260
+ attn_output = torch.matmul(attn_weights, value)
261
+ attn_output = attn_output.transpose(1, 2)
262
+
263
+ return attn_output, attn_weights
264
+
265
+ def _upcast_and_reordered_attn(
266
+ self, query, key, value, attention_mask=None, head_mask=None
267
+ ):
268
+ bsz, num_heads, q_seq_len, dk = query.size()
269
+ _, _, k_seq_len, _ = key.size()
270
+
271
+ attn_weights = torch.empty(
272
+ bsz * num_heads,
273
+ q_seq_len,
274
+ k_seq_len,
275
+ dtype=torch.float32,
276
+ device=query.device,
277
+ )
278
+
279
+ scale_factor = 1.0
280
+ if self.scale_attn_weights:
281
+ scale_factor /= float(value.size(-1)) ** 0.5
282
+
283
+ with autocast(enabled=False):
284
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
285
+ -1, dk, k_seq_len
286
+ )
287
+ attn_weights = torch.baddbmm(
288
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
289
+ )
290
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
291
+
292
+ query_length, key_length = query.size(-2), key.size(-2)
293
+ causal_mask = self.bias[
294
+ :, :, key_length - query_length : key_length, :key_length
295
+ ]
296
+ mask_value = torch.finfo(attn_weights.dtype).min
297
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
298
+ attn_weights.device
299
+ )
300
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
301
+
302
+ if attention_mask is not None:
303
+ attn_weights = attn_weights + attention_mask
304
+
305
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
306
+
307
+ if attn_weights.dtype != torch.float32:
308
+ raise RuntimeError(
309
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
310
+ )
311
+ attn_weights = attn_weights.type(value.dtype)
312
+ attn_weights = self.attn_dropout(attn_weights)
313
+
314
+ if head_mask is not None:
315
+ attn_weights = attn_weights * head_mask
316
+
317
+ attn_output = torch.matmul(attn_weights, value)
318
+
319
+ return attn_output, attn_weights
320
+
321
+ def _split_heads(self, tensor, num_heads, attn_head_size):
322
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
323
+ tensor = tensor.view(new_shape)
324
+ return tensor
325
+
326
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
327
+ tensor = tensor.contiguous()
328
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
329
+ return tensor.view(new_shape)
330
+
331
+ def forward(
332
+ self,
333
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
334
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
335
+ attention_mask: Optional[torch.FloatTensor] = None,
336
+ head_mask: Optional[torch.FloatTensor] = None,
337
+ encoder_hidden_states: Optional[torch.Tensor] = None,
338
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
339
+ output_attentions: Optional[bool] = False,
340
+ use_cache: Optional[bool] = False,
341
+ ):
342
+
343
+ mixed_x_layer = self.c_attn(hidden_states)
344
+ query, key, value = mixed_x_layer.split(self.split_size, dim=2)
345
+
346
+ query = self._split_heads(query, self.num_heads, self.head_dim)
347
+ key = self._split_heads(key, self.num_heads, self.head_dim)
348
+ value = self._split_heads(value, self.num_heads, self.head_dim)
349
+
350
+ kv_seq_len = hidden_states.size()[1]
351
+ if layer_past:
352
+ # layer past[0] shape: bs * seq_len * head_num * dim
353
+ kv_seq_len += layer_past[0].shape[1]
354
+ if (
355
+ self.use_dynamic_ntk
356
+ and kv_seq_len == hidden_states.size()[1]
357
+ and not self.training
358
+ ):
359
+ context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
360
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
361
+ ntk_alpha = max(ntk_alpha, 1)
362
+ self._ntk_cached = ntk_alpha
363
+
364
+ else:
365
+ ntk_alpha = self._ntk_cached
366
+
367
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
368
+ hidden_states.device
369
+ )
370
+
371
+ if rotary_pos_emb is not None:
372
+ if isinstance(rotary_pos_emb, tuple):
373
+ rotary_pos_emb = rotary_pos_emb
374
+ else:
375
+ rotary_pos_emb = (rotary_pos_emb,) * 2
376
+
377
+ if rotary_pos_emb is not None:
378
+ q_pos_emb, k_pos_emb = rotary_pos_emb
379
+ # Slice the pos emb for current inference
380
+ cur_len = query.shape[1]
381
+ q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
382
+ k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
383
+ query = apply_rotary_pos_emb(query, q_pos_emb)
384
+ key = apply_rotary_pos_emb(key, k_pos_emb)
385
+
386
+ if layer_past is not None:
387
+ past_key, past_value = layer_past[0], layer_past[1]
388
+ key = torch.cat((past_key, key), dim=1)
389
+ value = torch.cat((past_value, value), dim=1)
390
+
391
+ if use_cache:
392
+ present = (key, value)
393
+ else:
394
+ present = None
395
+
396
+ if self.use_logn_attn and not self.training:
397
+ if self.logn_tensor.device != query.device:
398
+ self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
399
+ seq_start = key.size(1) - query.size(1)
400
+ seq_end = key.size(1)
401
+ logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
402
+ query = query * logn_tensor.expand_as(query)
403
+
404
+ if (
405
+ self.use_flash_attn
406
+ and flash_attn_unpadded_func is not None
407
+ and not self.is_fp32
408
+ and query.is_cuda
409
+ ):
410
+ q, k, v = query, key, value
411
+ context_layer = self.core_attention_flash(q, k, v)
412
+
413
+ context_layer = rearrange(
414
+ context_layer, "b s h d -> b s (h d)"
415
+ ).contiguous()
416
+ else:
417
+ query = query.permute(0, 2, 1, 3)
418
+ key = key.permute(0, 2, 1, 3)
419
+ value = value.permute(0, 2, 1, 3)
420
+ attn_output, attn_weight = self._attn(
421
+ query, key, value, attention_mask, head_mask
422
+ )
423
+ context_layer = self._merge_heads(
424
+ attn_output, self.num_heads, self.head_dim
425
+ )
426
+
427
+ attn_output = self.c_proj(context_layer)
428
+ outputs = (attn_output, present)
429
+ if output_attentions:
430
+ if (
431
+ self.use_flash_attn
432
+ and flash_attn_unpadded_func is not None
433
+ and not self.is_fp32
434
+ ):
435
+ raise ValueError("Cannot output attentions while using flash-attn")
436
+ else:
437
+ outputs += (attn_weight,)
438
+
439
+ return outputs
440
+
441
+
442
+ class QWenMLP(nn.Module):
443
+ def __init__(self, config):
444
+ super().__init__()
445
+ self.w1 = nn.Linear(
446
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
447
+ )
448
+ self.w2 = nn.Linear(
449
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
450
+ )
451
+ ff_dim_in = config.ffn_hidden_size // 2
452
+ self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
453
+
454
+ def forward(self, hidden_states):
455
+ a1 = self.w1(hidden_states)
456
+ a2 = self.w2(hidden_states)
457
+ intermediate_parallel = a1 * F.silu(a2)
458
+ output = self.c_proj(intermediate_parallel)
459
+ return output
460
+
461
+
462
+ class QWenBlock(nn.Module):
463
+ def __init__(self, config, layer_idx=None, num_expert=1):
464
+ super().__init__()
465
+ self.num_expert = num_expert
466
+ self.layer_number = layer_idx
467
+ self.apply_residual_connection_post_layernorm = (
468
+ config.apply_residual_connection_post_layernorm
469
+ )
470
+ hidden_size = config.hidden_size
471
+ self.apply_residual_connection_post_layernorm = (
472
+ config.apply_residual_connection_post_layernorm
473
+ )
474
+ self.bf16 = config.bf16
475
+
476
+ self.ln_1 = RMSNorm(
477
+ hidden_size,
478
+ eps=config.layer_norm_epsilon,
479
+ )
480
+ self.attn = QWenAttention(config, layer_number=layer_idx)
481
+ self.ln_2 = RMSNorm(
482
+ hidden_size,
483
+ eps=config.layer_norm_epsilon,
484
+ )
485
+
486
+ self.mlp = QWenMLP(config)
487
+
488
+ def forward(
489
+ self,
490
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
491
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
492
+ attention_mask: Optional[torch.FloatTensor] = None,
493
+ head_mask: Optional[torch.FloatTensor] = None,
494
+ encoder_hidden_states: Optional[torch.Tensor] = None,
495
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
496
+ use_cache: Optional[bool] = False,
497
+ output_attentions: Optional[bool] = False,
498
+ ):
499
+ layernorm_output = self.ln_1(hidden_states)
500
+
501
+ attn_outputs = self.attn(
502
+ layernorm_output,
503
+ layer_past=layer_past,
504
+ attention_mask=attention_mask,
505
+ head_mask=head_mask,
506
+ use_cache=use_cache,
507
+ output_attentions=output_attentions,
508
+ )
509
+ attn_output = attn_outputs[0]
510
+
511
+ outputs = attn_outputs[1:]
512
+
513
+ if self.apply_residual_connection_post_layernorm:
514
+ residual = layernorm_output
515
+ else:
516
+ residual = hidden_states
517
+ layernorm_input = attn_output + residual
518
+
519
+ layernorm_output = self.ln_2(layernorm_input)
520
+
521
+ if self.apply_residual_connection_post_layernorm:
522
+ residual = layernorm_output
523
+ else:
524
+ residual = layernorm_input
525
+
526
+ mlp_output = self.mlp(layernorm_output)
527
+ hidden_states = residual + mlp_output
528
+
529
+ if use_cache:
530
+ outputs = (hidden_states,) + outputs
531
+ else:
532
+ outputs = (hidden_states,) + outputs[1:]
533
+
534
+ return outputs
535
+
536
+
537
+ class QWenPreTrainedModel(PreTrainedModel):
538
+ config_class = QWenConfig
539
+ base_model_prefix = "transformer"
540
+ is_parallelizable = False
541
+ supports_gradient_checkpointing = True
542
+ _no_split_modules = ["QWenBlock"]
543
+
544
+ def __init__(self, *inputs, **kwargs):
545
+ super().__init__(*inputs, **kwargs)
546
+
547
+ def _init_weights(self, module):
548
+ """Initialize the weights."""
549
+ if isinstance(module, nn.Linear):
550
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
551
+ if module.bias is not None:
552
+ module.bias.data.zero_()
553
+ elif isinstance(module, nn.Embedding):
554
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
555
+ if module.padding_idx is not None:
556
+ module.weight.data[module.padding_idx].zero_()
557
+ elif isinstance(module, RMSNorm):
558
+ module.weight.data.fill_(1.0)
559
+
560
+ for name, p in module.named_parameters():
561
+ if name == "c_proj.weight":
562
+ p.data.normal_(
563
+ mean=0.0,
564
+ std=(
565
+ self.config.initializer_range
566
+ / math.sqrt(2 * self.config.n_layer)
567
+ ),
568
+ )
569
+
570
+ def _set_gradient_checkpointing(self, module, value=False):
571
+ if isinstance(module, QWenModel):
572
+ module.gradient_checkpointing = value
573
+
574
+
575
+ class QWenModel(QWenPreTrainedModel):
576
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
577
+
578
+ def __init__(self, config):
579
+ super().__init__(config)
580
+ self.vocab_size = config.padded_vocab_size
581
+ self.num_hidden_layers = config.num_hidden_layers
582
+ self.embed_dim = config.hidden_size
583
+
584
+ max_sequence_length = config.max_position_embeddings
585
+ self.position_embedding_type = config.pos_emb
586
+ self.gradient_checkpointing = False
587
+
588
+ if self.position_embedding_type == "learned":
589
+ self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
590
+ self.init_method(self.position_embeddings.weight)
591
+ self._position_embeddings_key = "position_embeddings"
592
+ self.init_method(self.position_embeddings.weight)
593
+ else:
594
+ self.wpe = None
595
+ self._position_embeddings_key = ""
596
+
597
+ self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
598
+
599
+ self.drop = nn.Dropout(config.embd_pdrop)
600
+ self.h = nn.ModuleList(
601
+ [
602
+ QWenBlock(
603
+ config,
604
+ layer_idx=i,
605
+ )
606
+ for i in range(config.num_hidden_layers)
607
+ ]
608
+ )
609
+ self.ln_f = RMSNorm(
610
+ self.embed_dim,
611
+ eps=config.layer_norm_epsilon,
612
+ )
613
+
614
+ self.post_init()
615
+
616
+ def get_input_embeddings(self):
617
+ return self.wte
618
+
619
+ def set_input_embeddings(self, new_embeddings):
620
+ self.wte = new_embeddings
621
+
622
+ def forward(
623
+ self,
624
+ input_ids: Optional[torch.LongTensor] = None,
625
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
626
+ attention_mask: Optional[torch.FloatTensor] = None,
627
+ token_type_ids: Optional[torch.LongTensor] = None,
628
+ position_ids: Optional[torch.LongTensor] = None,
629
+ head_mask: Optional[torch.FloatTensor] = None,
630
+ inputs_embeds: Optional[torch.FloatTensor] = None,
631
+ encoder_hidden_states: Optional[torch.Tensor] = None,
632
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
633
+ use_cache: Optional[bool] = None,
634
+ output_attentions: Optional[bool] = None,
635
+ output_hidden_states: Optional[bool] = None,
636
+ return_dict: Optional[bool] = None,
637
+ ):
638
+ output_attentions = (
639
+ output_attentions
640
+ if output_attentions is not None
641
+ else self.config.output_attentions
642
+ )
643
+ output_hidden_states = (
644
+ output_hidden_states
645
+ if output_hidden_states is not None
646
+ else self.config.output_hidden_states
647
+ )
648
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
649
+ return_dict = (
650
+ return_dict if return_dict is not None else self.config.use_return_dict
651
+ )
652
+
653
+ if input_ids is not None and inputs_embeds is not None:
654
+ raise ValueError(
655
+ "You cannot specify both input_ids and inputs_embeds at the same time"
656
+ )
657
+ elif input_ids is not None:
658
+ input_shape = input_ids.size()
659
+ input_ids = input_ids.view(-1, input_shape[-1])
660
+ batch_size = input_ids.shape[0]
661
+ elif inputs_embeds is not None:
662
+ input_shape = inputs_embeds.size()[:-1]
663
+ batch_size = inputs_embeds.shape[0]
664
+ else:
665
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
666
+
667
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
668
+
669
+ if token_type_ids is not None:
670
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
671
+ if position_ids is not None:
672
+ position_ids = position_ids.view(-1, input_shape[-1])
673
+
674
+ if past_key_values is None:
675
+ past_length = 0
676
+ past_key_values = tuple([None] * len(self.h))
677
+ else:
678
+ past_length = past_key_values[0][0].size(-2)
679
+
680
+ if position_ids is None:
681
+ position_ids = torch.arange(
682
+ past_length,
683
+ input_shape[-1] + past_length,
684
+ dtype=torch.long,
685
+ device=device,
686
+ )
687
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
688
+
689
+ if attention_mask is not None:
690
+ if batch_size <= 0:
691
+ raise ValueError("batch_size has to be defined and > 0")
692
+ attention_mask = attention_mask.view(batch_size, -1)
693
+ attention_mask = attention_mask[:, None, None, :]
694
+ attention_mask = attention_mask.to(dtype=self.dtype)
695
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
696
+
697
+ encoder_attention_mask = None
698
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
699
+
700
+ if inputs_embeds is None:
701
+ inputs_embeds = self.wte(input_ids)
702
+ hidden_states = inputs_embeds
703
+ if self.wpe is not None:
704
+ position_embeds = self.wpe(position_ids)
705
+ hidden_states = hidden_states + position_embeds
706
+
707
+ hidden_states = self.drop(hidden_states)
708
+ output_shape = input_shape + (hidden_states.size(-1),)
709
+
710
+ if self.gradient_checkpointing and self.training:
711
+ if use_cache:
712
+ logger.warning_once(
713
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
714
+ )
715
+ use_cache = False
716
+
717
+ presents = () if use_cache else None
718
+ all_self_attentions = () if output_attentions else None
719
+ all_hidden_states = () if output_hidden_states else None
720
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
721
+
722
+ if output_hidden_states:
723
+ all_hidden_states = all_hidden_states + (hidden_states,)
724
+
725
+ if self.gradient_checkpointing and self.training:
726
+
727
+ def create_custom_forward(module):
728
+ def custom_forward(*inputs):
729
+ # None for past_key_value
730
+ return module(*inputs, use_cache, output_attentions)
731
+
732
+ return custom_forward
733
+
734
+ outputs = torch.utils.checkpoint.checkpoint(
735
+ create_custom_forward(block),
736
+ hidden_states,
737
+ None,
738
+ attention_mask,
739
+ head_mask[i],
740
+ encoder_hidden_states,
741
+ encoder_attention_mask,
742
+ )
743
+ else:
744
+ outputs = block(
745
+ hidden_states,
746
+ layer_past=layer_past,
747
+ attention_mask=attention_mask,
748
+ head_mask=head_mask[i],
749
+ encoder_hidden_states=encoder_hidden_states,
750
+ encoder_attention_mask=encoder_attention_mask,
751
+ use_cache=use_cache,
752
+ output_attentions=output_attentions,
753
+ )
754
+
755
+ hidden_states = outputs[0]
756
+ if use_cache is True:
757
+ presents = presents + (outputs[2 if output_attentions else 1],)
758
+
759
+ if output_attentions:
760
+ all_self_attentions = all_self_attentions + (outputs[1],)
761
+
762
+ hidden_states = self.ln_f(hidden_states)
763
+ hidden_states = hidden_states.view(output_shape)
764
+
765
+ if not return_dict:
766
+ return tuple(
767
+ v for v in [hidden_states, presents, all_hidden_states] if v is not None
768
+ )
769
+
770
+ return BaseModelOutputWithPast(
771
+ last_hidden_state=hidden_states,
772
+ past_key_values=presents,
773
+ hidden_states=all_hidden_states,
774
+ attentions=all_self_attentions,
775
+ )
776
+
777
+
778
+ class QWenLMHeadModel(QWenPreTrainedModel):
779
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
780
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
781
+
782
+ def __init__(self, config):
783
+ super().__init__(config)
784
+ self.transformer = QWenModel(config)
785
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
786
+ assert not (
787
+ config.bf16 and config.fp16
788
+ ), "In config, bf16 and fp16 cannot both be true"
789
+ if config.bf16:
790
+ self.transformer.bfloat16()
791
+ self.lm_head.bfloat16()
792
+ if config.fp16:
793
+ self.transformer.half()
794
+ self.lm_head.half()
795
+ self.post_init()
796
+
797
+ def get_output_embeddings(self):
798
+ return self.lm_head
799
+
800
+ def set_output_embeddings(self, new_embeddings):
801
+ self.lm_head = new_embeddings
802
+
803
+ def prepare_inputs_for_generation(
804
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
805
+ ):
806
+ token_type_ids = kwargs.get("token_type_ids", None)
807
+ if past_key_values:
808
+ input_ids = input_ids[:, -1].unsqueeze(-1)
809
+ if token_type_ids is not None:
810
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
811
+
812
+ attention_mask = kwargs.get("attention_mask", None)
813
+ position_ids = kwargs.get("position_ids", None)
814
+
815
+ if attention_mask is not None and position_ids is None:
816
+ position_ids = attention_mask.long().cumsum(-1) - 1
817
+ position_ids.masked_fill_(attention_mask == 0, 1)
818
+ if past_key_values:
819
+ position_ids = position_ids[:, -1].unsqueeze(-1)
820
+ else:
821
+ position_ids = None
822
+
823
+ if inputs_embeds is not None and past_key_values is None:
824
+ model_inputs = {"inputs_embeds": inputs_embeds}
825
+ else:
826
+ model_inputs = {"input_ids": input_ids}
827
+
828
+ model_inputs.update(
829
+ {
830
+ "past_key_values": past_key_values,
831
+ "use_cache": kwargs.get("use_cache"),
832
+ "position_ids": position_ids,
833
+ "attention_mask": attention_mask,
834
+ "token_type_ids": token_type_ids,
835
+ }
836
+ )
837
+ return model_inputs
838
+
839
+ def forward(
840
+ self,
841
+ input_ids: Optional[torch.LongTensor] = None,
842
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
843
+ attention_mask: Optional[torch.FloatTensor] = None,
844
+ token_type_ids: Optional[torch.LongTensor] = None,
845
+ position_ids: Optional[torch.LongTensor] = None,
846
+ head_mask: Optional[torch.FloatTensor] = None,
847
+ inputs_embeds: Optional[torch.FloatTensor] = None,
848
+ encoder_hidden_states: Optional[torch.Tensor] = None,
849
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
850
+ labels: Optional[torch.LongTensor] = None,
851
+ use_cache: Optional[bool] = None,
852
+ output_attentions: Optional[bool] = None,
853
+ output_hidden_states: Optional[bool] = None,
854
+ return_dict: Optional[bool] = None,
855
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
856
+
857
+ return_dict = (
858
+ return_dict if return_dict is not None else self.config.use_return_dict
859
+ )
860
+
861
+ transformer_outputs = self.transformer(
862
+ input_ids,
863
+ past_key_values=past_key_values,
864
+ attention_mask=attention_mask,
865
+ token_type_ids=token_type_ids,
866
+ position_ids=position_ids,
867
+ head_mask=head_mask,
868
+ inputs_embeds=inputs_embeds,
869
+ encoder_hidden_states=encoder_hidden_states,
870
+ encoder_attention_mask=encoder_attention_mask,
871
+ use_cache=use_cache,
872
+ output_attentions=output_attentions,
873
+ output_hidden_states=output_hidden_states,
874
+ return_dict=return_dict,
875
+ )
876
+ hidden_states = transformer_outputs[0]
877
+
878
+ lm_logits = self.lm_head(hidden_states)
879
+
880
+ loss = None
881
+ if labels is not None:
882
+ labels = labels.to(lm_logits.device)
883
+ shift_logits = lm_logits[..., :-1, :].contiguous()
884
+ shift_labels = labels[..., 1:].contiguous()
885
+ loss_fct = CrossEntropyLoss()
886
+ loss = loss_fct(
887
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
888
+ )
889
+
890
+ if not return_dict:
891
+ output = (lm_logits,) + transformer_outputs[1:]
892
+ return ((loss,) + output) if loss is not None else output
893
+ if self.training:
894
+ lm_logits=None
895
+
896
+ return CausalLMOutputWithPast(
897
+ loss=loss,
898
+ logits=lm_logits,
899
+ past_key_values=transformer_outputs.past_key_values,
900
+ hidden_states=transformer_outputs.hidden_states,
901
+ attentions=transformer_outputs.attentions,
902
+ )
903
+
904
+ @staticmethod
905
+ def _reorder_cache(
906
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
907
+ ) -> Tuple[Tuple[torch.Tensor]]:
908
+
909
+ return tuple(
910
+ tuple(
911
+ past_state.index_select(0, beam_idx.to(past_state.device))
912
+ for past_state in layer_past
913
+ )
914
+ for layer_past in past_key_values
915
+ )
916
+
917
+ def chat(
918
+ self,
919
+ tokenizer: PreTrainedTokenizer,
920
+ query: str,
921
+ history: Optional[HistoryType],
922
+ system: str = "You are a helpful assistant.",
923
+ append_history: bool = True,
924
+ ) -> Tuple[str, HistoryType]:
925
+
926
+ if history is None:
927
+ history = []
928
+
929
+ raw_text, context_tokens = make_context(
930
+ tokenizer,
931
+ query,
932
+ history=history,
933
+ system=system,
934
+ max_window_size=6144,
935
+ chat_format=self.generation_config.chat_format,
936
+ )
937
+
938
+ stop_words_ids = get_stop_words_ids(
939
+ self.generation_config.chat_format, tokenizer
940
+ )
941
+ input_ids = torch.tensor([context_tokens]).to(self.device)
942
+
943
+ outputs = self.generate(
944
+ input_ids,
945
+ stop_words_ids=stop_words_ids,
946
+ return_dict_in_generate=False,
947
+ )
948
+
949
+ response = decode_tokens(
950
+ outputs[0],
951
+ tokenizer,
952
+ raw_text_len=len(raw_text),
953
+ context_length=len(context_tokens),
954
+ chat_format=self.generation_config.chat_format,
955
+ verbose=False,
956
+ )
957
+
958
+ if append_history:
959
+ history.append((query, response))
960
+
961
+ return response, history
962
+
963
+ def generate(
964
+ self,
965
+ inputs: Optional[torch.Tensor] = None,
966
+ generation_config: Optional[GenerationConfig] = None,
967
+ logits_processor: Optional[LogitsProcessorList] = None,
968
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
969
+ prefix_allowed_tokens_fn: Optional[
970
+ Callable[[int, torch.Tensor], List[int]]
971
+ ] = None,
972
+ synced_gpus: Optional[bool] = None,
973
+ streamer: Optional["BaseStreamer"] = None,
974
+ **kwargs,
975
+ ) -> Union[GenerateOutput, torch.LongTensor]:
976
+ # Process stop_words_ids.
977
+ stop_words_ids = kwargs.pop("stop_words_ids", None)
978
+ if stop_words_ids is None and generation_config is not None:
979
+ stop_words_ids = getattr(generation_config, "stop_words_ids", None)
980
+ if stop_words_ids is None:
981
+ stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
982
+
983
+ if stop_words_ids is not None:
984
+ stop_words_logits_processor = StopWordsLogitsProcessor(
985
+ stop_words_ids=stop_words_ids,
986
+ eos_token_id=self.generation_config.eos_token_id,
987
+ )
988
+ if logits_processor is None:
989
+ logits_processor = LogitsProcessorList([stop_words_logits_processor])
990
+ else:
991
+ logits_processor.append(stop_words_logits_processor)
992
+
993
+ return super().generate(
994
+ inputs,
995
+ generation_config,
996
+ logits_processor,
997
+ stopping_criteria,
998
+ prefix_allowed_tokens_fn,
999
+ synced_gpus,
1000
+ streamer,
1001
+ **kwargs,
1002
+ )
1003
+
1004
+
1005
+ class RotaryEmbedding(torch.nn.Module):
1006
+ def __init__(self, dim, base=10000):
1007
+ super().__init__()
1008
+ self.dim = dim
1009
+ self.base = base
1010
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
1011
+ if importlib.util.find_spec("einops") is None:
1012
+ raise RuntimeError("einops is required for Rotary Embedding")
1013
+
1014
+ self._rotary_pos_emb_cache = None
1015
+ self._seq_len_cached = 0
1016
+ self._ntk_alpha_cached = 1.0
1017
+
1018
+ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
1019
+ seqlen = max_seq_len + offset
1020
+ if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
1021
+ base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
1022
+ self.inv_freq = 1.0 / (
1023
+ base
1024
+ ** (
1025
+ torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
1026
+ / self.dim
1027
+ )
1028
+ )
1029
+ self._seq_len_cached = seqlen
1030
+ self._ntk_alpha_cached = ntk_alpha
1031
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=torch.float32)
1032
+ freqs = torch.outer(seq, self.inv_freq)
1033
+ emb = torch.cat((freqs, freqs), dim=-1)
1034
+ from einops import rearrange
1035
+
1036
+ self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
1037
+
1038
+ def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
1039
+ self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
1040
+ return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
1041
+
1042
+
1043
+ def _rotate_half(x):
1044
+ from einops import rearrange
1045
+
1046
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
1047
+ x1, x2 = x.unbind(dim=-2)
1048
+ return torch.cat((-x2, x1), dim=-1)
1049
+
1050
+
1051
+ def apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):
1052
+ if use_flash_rotary:
1053
+ t_ = t.float()
1054
+ freqs = freqs.squeeze(0).squeeze(1)
1055
+ cos = freqs[:, : freqs.shape[-1] // 2].cos()
1056
+ sin = freqs[:, : freqs.shape[-1] // 2].sin()
1057
+ output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
1058
+ return output
1059
+ else:
1060
+ rot_dim = freqs.shape[-1]
1061
+ t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
1062
+ t_ = t_.float()
1063
+ t_pass_ = t_pass_.float()
1064
+ t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
1065
+ return torch.cat((t_, t_pass_), dim=-1).type_as(t)
1066
+
1067
+
1068
+ class RMSNorm(torch.nn.Module):
1069
+ def __init__(self, dim: int, eps: float = 1e-6):
1070
+ super().__init__()
1071
+ self.eps = eps
1072
+ self.weight = nn.Parameter(torch.ones(dim))
1073
+
1074
+ def _norm(self, x):
1075
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1076
+
1077
+ def forward(self, x):
1078
+ if rms_norm is not None and x.is_cuda:
1079
+ return rms_norm(x, self.weight, self.eps)
1080
+ else:
1081
+ output = self._norm(x.float()).type_as(x)
1082
+ return output * self.weight
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a5979dce941fc12f841303087367d39fbd3e726227b85330bb2abdb7255781d
3
+ size 9969772092
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57967253f3d177b73af70eb6c56fc83072c56030f42041f0b20ca5c3be67dbda
3
+ size 5472963479
qwen_generation_utils.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Generation support."""
7
+
8
+ from typing import Tuple, List, Union, Iterable
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedTokenizer
14
+ from transformers import logging
15
+ from transformers.generation import LogitsProcessor
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ # Types.
20
+ HistoryType = List[Tuple[str, str]]
21
+ TokensType = List[int]
22
+ BatchTokensType = List[List[int]]
23
+
24
+
25
+ def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26
+ for tokens in batch:
27
+ context_length = len(tokens)
28
+ if context_length < seq_length:
29
+ tokens.extend([pad_id] * (seq_length - context_length))
30
+ return batch
31
+
32
+
33
+ def get_ltor_masks_and_position_ids(
34
+ data,
35
+ eod_token,
36
+ reset_position_ids,
37
+ reset_attention_mask,
38
+ eod_mask_loss,
39
+ ):
40
+ """Build masks and position id for left to right model."""
41
+
42
+ # Extract batch size and sequence length.
43
+ micro_batch_size, seq_length = data.size()
44
+
45
+ # Attention mask (lower triangular).
46
+ if reset_attention_mask:
47
+ att_mask_batch = micro_batch_size
48
+ else:
49
+ att_mask_batch = 1
50
+ attention_mask = torch.tril(
51
+ torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52
+ ).view(att_mask_batch, 1, seq_length, seq_length)
53
+
54
+ # Loss mask.
55
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56
+ if eod_mask_loss:
57
+ loss_mask[data == eod_token] = 0.0
58
+
59
+ # Position ids.
60
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
62
+ # We need to clone as the ids will be modifed based on batch index.
63
+ if reset_position_ids:
64
+ position_ids = position_ids.clone()
65
+
66
+ if reset_position_ids or reset_attention_mask:
67
+ # Loop through the batches:
68
+ for b in range(micro_batch_size):
69
+
70
+ # Find indecies where EOD token is.
71
+ eod_index = position_ids[b, data[b] == eod_token]
72
+ # Detach indecies from positions if going to modify positions.
73
+ if reset_position_ids:
74
+ eod_index = eod_index.clone()
75
+
76
+ # Loop through EOD indecies:
77
+ prev_index = 0
78
+ for j in range(eod_index.size()[0]):
79
+ i = eod_index[j]
80
+ # Mask attention loss.
81
+ if reset_attention_mask:
82
+ attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83
+ # Reset positions.
84
+ if reset_position_ids:
85
+ position_ids[b, (i + 1) :] -= i + 1 - prev_index
86
+ prev_index = i + 1
87
+
88
+ # Convert attention mask to binary:
89
+ attention_mask = attention_mask < 0.5
90
+
91
+ return attention_mask, loss_mask, position_ids
92
+
93
+
94
+ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95
+ """Generate batch from context tokens."""
96
+ # Move to GPU.
97
+ tokens = context_tokens.contiguous().to(context_tokens.device)
98
+ # Get the attention mask and postition ids.
99
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100
+ tokens,
101
+ eod_id,
102
+ reset_position_ids=False,
103
+ reset_attention_mask=False,
104
+ eod_mask_loss=False,
105
+ )
106
+ return tokens, attention_mask, position_ids
107
+
108
+
109
+ def get_stop_words_ids(chat_format, tokenizer):
110
+ if chat_format == "raw":
111
+ stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112
+ elif chat_format == "chatml":
113
+ stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114
+ else:
115
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116
+ return stop_words_ids
117
+
118
+
119
+ def make_context(
120
+ tokenizer: PreTrainedTokenizer,
121
+ query: str,
122
+ history: List[Tuple[str, str]] = None,
123
+ system: str = "",
124
+ max_window_size: int = 8192,
125
+ chat_format: str = "chatml",
126
+ ):
127
+ if history is None:
128
+ history = []
129
+
130
+ if chat_format == "chatml":
131
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
132
+ im_start_tokens = [tokenizer.im_start_id]
133
+ im_end_tokens = [tokenizer.im_end_id]
134
+ nl_tokens = tokenizer.encode("\n")
135
+
136
+ def _tokenize_str(role, content):
137
+ return f"{role}\n{content}", tokenizer.encode(
138
+ role
139
+ ) + nl_tokens + tokenizer.encode(content)
140
+
141
+ system_text, system_tokens_part = _tokenize_str("system", system)
142
+ system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143
+
144
+ raw_text = ""
145
+ context_tokens = []
146
+
147
+ for turn_query, turn_response in reversed(history):
148
+ query_text, query_tokens_part = _tokenize_str("user", turn_query)
149
+ query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150
+ response_text, response_tokens_part = _tokenize_str(
151
+ "assistant", turn_response
152
+ )
153
+ response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154
+
155
+ next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156
+ prev_chat = (
157
+ f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158
+ )
159
+
160
+ current_context_size = (
161
+ len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162
+ )
163
+ if current_context_size < max_window_size:
164
+ context_tokens = next_context_tokens + context_tokens
165
+ raw_text = prev_chat + raw_text
166
+ else:
167
+ break
168
+
169
+ context_tokens = system_tokens + context_tokens
170
+ raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171
+ context_tokens += (
172
+ nl_tokens
173
+ + im_start_tokens
174
+ + _tokenize_str("user", query)[1]
175
+ + im_end_tokens
176
+ + nl_tokens
177
+ + im_start_tokens
178
+ + tokenizer.encode("assistant")
179
+ + nl_tokens
180
+ )
181
+ raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182
+
183
+ elif chat_format == "raw":
184
+ raw_text = query
185
+ context_tokens = tokenizer.encode(raw_text)
186
+ else:
187
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188
+
189
+ return raw_text, context_tokens
190
+
191
+
192
+ def _decode_default(
193
+ tokens: List[int],
194
+ *,
195
+ stop_words: List[str],
196
+ eod_words: List[str],
197
+ tokenizer: PreTrainedTokenizer,
198
+ raw_text_len: int,
199
+ verbose: bool = False,
200
+ return_end_reason: bool = False,
201
+ ):
202
+ trim_decode_tokens = tokenizer.decode(tokens)[raw_text_len:]
203
+ if verbose:
204
+ print("\nRaw Generate: ", trim_decode_tokens)
205
+
206
+ end_reason = f"Gen length {len(tokens)}"
207
+ for stop_word in stop_words:
208
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
209
+ for eod_word in eod_words:
210
+ if eod_word in trim_decode_tokens:
211
+ end_reason = f"Gen {eod_word!r}"
212
+ trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
213
+ trim_decode_tokens = trim_decode_tokens.strip()
214
+ if verbose:
215
+ print("\nEnd Reason:", end_reason)
216
+ print("\nGenerate: ", trim_decode_tokens)
217
+
218
+ if return_end_reason:
219
+ return trim_decode_tokens, end_reason
220
+ else:
221
+ return trim_decode_tokens
222
+
223
+
224
+ def _decode_chatml(
225
+ tokens: List[int],
226
+ *,
227
+ stop_words: List[str],
228
+ eod_token_ids: List[int],
229
+ tokenizer: PreTrainedTokenizer,
230
+ raw_text_len: int,
231
+ context_length: int,
232
+ verbose: bool = False,
233
+ return_end_reason: bool = False,
234
+ ):
235
+ end_reason = f"Gen length {len(tokens)}"
236
+ eod_token_idx = context_length
237
+ for eod_token_idx in range(context_length, len(tokens)):
238
+ if tokens[eod_token_idx] in eod_token_ids:
239
+ end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
240
+ break
241
+
242
+ trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx])[raw_text_len:]
243
+ if verbose:
244
+ print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens)[raw_text_len:])
245
+ print("\nRaw Generate:", trim_decode_tokens)
246
+ print("\nEnd Reason:", end_reason)
247
+ for stop_word in stop_words:
248
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
249
+ trim_decode_tokens = trim_decode_tokens.strip()
250
+ if verbose:
251
+ print("\nGenerate:", trim_decode_tokens)
252
+
253
+ if return_end_reason:
254
+ return trim_decode_tokens, end_reason
255
+ else:
256
+ return trim_decode_tokens
257
+
258
+
259
+ def decode_tokens(
260
+ tokens: Union[torch.LongTensor, TokensType],
261
+ tokenizer: PreTrainedTokenizer,
262
+ raw_text_len: int,
263
+ context_length: int,
264
+ chat_format: str,
265
+ verbose: bool = False,
266
+ return_end_reason: bool = False,
267
+ ) -> str:
268
+ if torch.is_tensor(tokens):
269
+ tokens = tokens.cpu().numpy().tolist()
270
+
271
+ if chat_format == "chatml":
272
+ return _decode_chatml(
273
+ tokens,
274
+ stop_words=[],
275
+ eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
276
+ tokenizer=tokenizer,
277
+ raw_text_len=raw_text_len,
278
+ context_length=context_length,
279
+ verbose=verbose,
280
+ return_end_reason=return_end_reason,
281
+ )
282
+ elif chat_format == "raw":
283
+ return _decode_default(
284
+ tokens,
285
+ stop_words=["<|endoftext|>"],
286
+ eod_words=["<|endoftext|>"],
287
+ tokenizer=tokenizer,
288
+ raw_text_len=raw_text_len,
289
+ verbose=verbose,
290
+ return_end_reason=return_end_reason,
291
+ )
292
+ else:
293
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
294
+
295
+
296
+ class StopWordsLogitsProcessor(LogitsProcessor):
297
+ """
298
+ :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
299
+
300
+ Args:
301
+ stop_words_ids (:obj:`List[List[int]]`):
302
+ List of list of token ids of stop ids. In order to get the tokens of the words
303
+ that should not appear in the generated text, use :obj:`tokenizer(bad_word,
304
+ add_prefix_space=True).input_ids`.
305
+ eos_token_id (:obj:`int`):
306
+ The id of the `end-of-sequence` token.
307
+ """
308
+
309
+ def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
310
+
311
+ if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
312
+ raise ValueError(
313
+ f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
314
+ )
315
+ if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
316
+ raise ValueError(
317
+ f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
318
+ )
319
+ if any(
320
+ any(
321
+ (not isinstance(token_id, (int, np.integer)) or token_id < 0)
322
+ for token_id in stop_word_ids
323
+ )
324
+ for stop_word_ids in stop_words_ids
325
+ ):
326
+ raise ValueError(
327
+ f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
328
+ )
329
+
330
+ self.stop_words_ids = list(
331
+ filter(
332
+ lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
333
+ )
334
+ )
335
+ self.eos_token_id = eos_token_id
336
+ for stop_token_seq in self.stop_words_ids:
337
+ assert (
338
+ len(stop_token_seq) > 0
339
+ ), "Stop words token sequences {} cannot have an empty list".format(
340
+ stop_words_ids
341
+ )
342
+
343
+ def __call__(
344
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
345
+ ) -> torch.FloatTensor:
346
+ stopped_samples = self._calc_stopped_samples(input_ids)
347
+ for i, should_stop in enumerate(stopped_samples):
348
+ if should_stop:
349
+ scores[i, self.eos_token_id] = float(2**15)
350
+ return scores
351
+
352
+ def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
353
+ if len(tokens) == 0:
354
+ # if bad word tokens is just one token always ban it
355
+ return True
356
+ elif len(tokens) > len(prev_tokens):
357
+ # if bad word tokens are longer then prev input_ids they can't be equal
358
+ return False
359
+ elif prev_tokens[-len(tokens) :].tolist() == tokens:
360
+ # if tokens match
361
+ return True
362
+ else:
363
+ return False
364
+
365
+ def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
366
+ stopped_samples = []
367
+ for prev_input_ids_slice in prev_input_ids:
368
+ match = False
369
+ for stop_token_seq in self.stop_words_ids:
370
+ if self._tokens_match(prev_input_ids_slice, stop_token_seq):
371
+ # if tokens do not match continue
372
+ match = True
373
+ break
374
+ stopped_samples.append(match)
375
+
376
+ return stopped_samples
377
+
378
+
379
+ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
380
+ """This function has been mostly taken from huggingface conversational
381
+ ai code at
382
+ https://medium.com/huggingface/how-to-build-a-state-of-the-art-
383
+ conversational-ai-with-transfer-learning-2d818ac26313"""
384
+
385
+ if top_k > 0:
386
+ # Remove all tokens with a probability less than the
387
+ # last token of the top-k
388
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
389
+ logits[indices_to_remove] = filter_value
390
+
391
+ if top_p > 0.0:
392
+ # Cconvert to 1D
393
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
394
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
395
+
396
+ # Remove tokens with cumulative probability above the threshold
397
+ sorted_indices_to_remove = cumulative_probs > top_p
398
+ # Shift the indices to the right to keep also the first token
399
+ # above the threshold
400
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
401
+ sorted_indices_to_remove[..., 0] = 0
402
+ for i in range(sorted_indices.size(0)):
403
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
404
+ logits[i][indices_to_remove] = filter_value
405
+
406
+ return logits
407
+
408
+
409
+ def switch(val1, val2, boolean):
410
+ boolean = boolean.type_as(val1)
411
+ return (1 - boolean) * val1 + boolean * val2
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "unk_token": "<|endoftext|>"
5
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 8192,
23
+ "pad_token": null,
24
+ "padding_side": "left",
25
+ "tokenizer_class": "GPT2Tokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff