SnifferCaptain committed on
Commit b425c8f · verified · 1 Parent(s): 9fb0bc4

Upload 5 files

Files changed (5)
  1. config.json +97 -0
  2. model.safetensors +3 -0
  3. tokenizer.json +0 -0
  4. tokenizer_config.json +335 -0
  5. ymodel3_eval.py +755 -0
config.json ADDED
@@ -0,0 +1,97 @@
+ {
+   "return_dict": true,
+   "output_hidden_states": false,
+   "torchscript": false,
+   "dtype": "float32",
+   "_output_attentions": false,
+   "pruned_heads": {},
+   "tie_word_embeddings": true,
+   "chunk_size_feed_forward": 0,
+   "is_encoder_decoder": false,
+   "is_decoder": false,
+   "cross_attention_hidden_size": null,
+   "add_cross_attention": false,
+   "tie_encoder_decoder": false,
+   "architectures": null,
+   "finetuning_task": null,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1
+   },
+   "task_specific_params": null,
+   "problem_type": null,
+   "tokenizer_class": null,
+   "prefix": null,
+   "bos_token_id": 151644,
+   "pad_token_id": 151643,
+   "eos_token_id": 151645,
+   "sep_token_id": null,
+   "decoder_start_token_id": null,
+   "max_length": 20,
+   "min_length": 0,
+   "do_sample": false,
+   "early_stopping": false,
+   "num_beams": 1,
+   "temperature": 1.0,
+   "top_k": 50,
+   "top_p": 1.0,
+   "typical_p": 1.0,
+   "repetition_penalty": 1.0,
+   "length_penalty": 1.0,
+   "no_repeat_ngram_size": 0,
+   "encoder_no_repeat_ngram_size": 0,
+   "bad_words_ids": null,
+   "num_return_sequences": 1,
+   "output_scores": false,
+   "return_dict_in_generate": false,
+   "forced_bos_token_id": null,
+   "forced_eos_token_id": null,
+   "remove_invalid_values": false,
+   "exponential_decay_length_penalty": null,
+   "suppress_tokens": null,
+   "begin_suppress_tokens": null,
+   "num_beam_groups": 1,
+   "diversity_penalty": 0.0,
+   "_name_or_path": "",
+   "_commit_hash": null,
+   "_attn_implementation_internal": "eager",
+   "transformers_version": null,
+   "tf_legacy_loss": false,
+   "use_bfloat16": false,
+   "dropout": 0.0,
+   "hidden_act": "silu",
+   "hidden_size": 768,
+   "num_hidden_layers": 8,
+   "max_position_embeddings": 4096,
+   "vocab_size": 6400,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 50000.0,
+   "rope_scaling": null,
+   "self_distill": true,
+   "intermediate_size": 1536,
+   "expert_intermediate_size": 768,
+   "n_routed_experts": 0,
+   "moe_topk": 2,
+   "score_func": "softmax",
+   "n_shared_experts": 0,
+   "top_k_layer_dense": 8,
+   "aux_loss_alpha": 0.02,
+   "seq_aux": false,
+   "norm_topk_prob": true,
+   "noisy_expert": 0.0,
+   "moe_backend": "compact",
+   "router_bias_enabled": true,
+   "router_bias_update_rate": 0.001,
+   "router_bias_clamp": 5.0,
+   "num_heads": 6,
+   "mla_kv_lora_rank": 128,
+   "mla_qk_nope_head_dim": 64,
+   "mla_qk_rope_head_dim": 64,
+   "mla_attn_impl": "absorb",
+   "qkv_lora": false,
+   "torch_dtype": "bfloat16"
+ }
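
Usage note (not part of the commit): this config can be read back with the YConfig3 helper defined in ymodel3_eval.py below. A minimal sketch, assuming config.json sits in the working directory:

    from ymodel3_eval import YConfig3

    config = YConfig3.from_json_file("config.json")
    print(config.hidden_size)    # 768
    print(config.mla_attn_impl)  # "absorb"
    print(config.head_dim)       # 128 = mla_qk_nope_head_dim + mla_qk_rope_head_dim
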
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4447ad64478d7bdb64c55ade2a9a027b4d7c7ac8ac0420e26474e9f14755f795
+ size 110141736
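
Usage note (not part of the commit): the LFS pointer above records the blob's sha256 digest and byte size, so a downloaded model.safetensors can be verified locally. A minimal sketch:

    import hashlib
    from pathlib import Path

    EXPECTED_OID = "4447ad64478d7bdb64c55ade2a9a027b4d7c7ac8ac0420e26474e9f14755f795"
    EXPECTED_SIZE = 110141736

    blob = Path("model.safetensors").read_bytes()
    assert len(blob) == EXPECTED_SIZE, "size mismatch (still an LFS pointer file?)"
    assert hashlib.sha256(blob).hexdigest() == EXPECTED_OID, "sha256 mismatch"
    print("model.safetensors matches the LFS pointer")
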
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,335 @@
+ {
+   "add_bos_token": false,
+   "add_eos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "5": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "6": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "7": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "8": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "9": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "10": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "11": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "12": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "13": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "14": {
+       "content": "<|audio_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "15": {
+       "content": "<|audio_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "16": {
+       "content": "<|audio_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "17": {
+       "content": "<tts_pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "18": {
+       "content": "<tts_text_bos>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "19": {
+       "content": "<tts_text_eod>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "20": {
+       "content": "<tts_text_bos_single>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "21": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "22": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "23": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "24": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "25": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "26": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "27": {
+       "content": "<|buffer1|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "28": {
+       "content": "<|buffer2|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "29": {
+       "content": "<|buffer3|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "30": {
+       "content": "<|buffer4|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "31": {
+       "content": "<|buffer5|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "32": {
+       "content": "<|buffer6|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "33": {
+       "content": "<|buffer7|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "34": {
+       "content": "<|buffer8|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "35": {
+       "content": "<|buffer9|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>",
+     "<|audio_start|>",
+     "<|audio_end|>",
+     "<|audio_pad|>",
+     "<tts_pad>",
+     "<tts_text_bos>",
+     "<tts_text_eod>",
+     "<tts_text_bos_single>"
+   ],
+   "bos_token": "<|im_start|>",
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "legacy": true,
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "sp_model_kwargs": {},
+   "spaces_between_special_tokens": false,
+   "unk_token": "<|endoftext|>",
+   "image_token": "<|image_pad|>",
+   "audio_token": "<|audio_pad|>",
+   "video_token": "<|video_pad|>",
+   "vision_bos_token": "<|vision_start|>",
+   "vision_eos_token": "<|vision_end|>",
+   "audio_bos_token": "<|audio_start|>",
+   "audio_eos_token": "<|audio_end|>",
+   "chat_template": "{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0].role == 'system' %}{{- messages[0].content + '\\n\\n' }}{%- endif %}{{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}{%- else %}{%- if messages[0].role == 'system' %}{{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}{%- endif %}{%- endif %}{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}{%- for message in messages[::-1] %}{%- set index = (messages|length - 1) - loop.index0 %}{%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}{%- set ns.multi_step_tool = false %}{%- set ns.last_query_index = index %}{%- endif %}{%- endfor %}{%- for message in messages %}{%- if message.content is string %}{%- set content = message.content %}{%- else %}{%- set content = '' %}{%- endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}{{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}{%- elif message.role == \"assistant\" %}{%- set reasoning_content = '' %}{%- if message.reasoning_content is string %}{%- set reasoning_content = message.reasoning_content %}{%- else %}{%- if '</think>' in content %}{%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}{%- set content = content.split('</think>')[-1].lstrip('\\n') %}{%- endif %}{%- endif %}{%- if true %}{{- '<|im_start|>' + message.role + '\\n<think>' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}{%- endif %}{%- if message.tool_calls %}{%- for tool_call in message.tool_calls %}{%- if (loop.first and content) or (not loop.first) %}{{- '\\n' }}{%- endif %}{%- if tool_call.function %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '<tool_call>\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{%- if tool_call.arguments is string %}{{- tool_call.arguments }}{%- else %}{{- tool_call.arguments | tojson }}{%- endif %}{{- '}\\n</tool_call>' }}{%- endfor %}{%- endif %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n<tool_response>\\n' }}{{- content }}{{- '\\n</tool_response>' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{- '<|im_start|>assistant\\n' }}{%- set juice_value = thinking_juice if thinking_juice is defined else 2.00 %}{%- set juice_str = '%.2f' | format(juice_value) %}{%- if open_thinking is defined and open_thinking %}{{- '<think>juice = ' + juice_str + '\\n' }}{%- else %}{{- '<think>juice = ' + juice_str + '\\n</think>\\n\\n' }}{%- endif %}{%- endif %}",
+   "tokenizer_class": "PreTrainedTokenizerFast"
+ }
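
Usage note (not part of the commit): the chat_template above reads two optional template variables, open_thinking (leave the <think> block open for the model to fill) and thinking_juice (a float rendered into the prompt, default 2.00). A minimal sketch with transformers' AutoTokenizer, which forwards extra apply_chat_template keyword arguments to the template; the repo path is a placeholder:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("path/to/this/repo")
    messages = [{"role": "user", "content": "What is 2 + 2?"}]
    prompt = tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        open_thinking=True,   # emit "<think>juice = 2.00" and stop there
        thinking_juice=2.0,
    )
    print(prompt)
    # <|im_start|>user
    # What is 2 + 2?<|im_end|>
    # <|im_start|>assistant
    # <think>juice = 2.00
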
ymodel3_eval.py ADDED
@@ -0,0 +1,755 @@
+ """Self-contained ymodel3 inference module.
+
+ Only depends on: torch, safetensors.
+ No dependency on kernel.*, model.ymodel3, transformers.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import math
+ from pathlib import Path
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from safetensors.torch import load_file as load_safetensors
+
+
+ # ── Config ──────────────────────────────────────────────────────────
+
+
+ class YConfig3:
+     model_type = "ynet3"
+
+     def __init__(self, **kwargs):
+         self.dropout = kwargs.get("dropout", 0.0)
+         self.bos_token_id = kwargs.get("bos_token_id", 151644)
+         self.eos_token_id = kwargs.get("eos_token_id", 151645)
+         self.pad_token_id = kwargs.get("pad_token_id", 151643)
+         self.hidden_act = kwargs.get("hidden_act", "silu")
+         self.hidden_size = kwargs.get("hidden_size", 768)
+         self.num_hidden_layers = kwargs.get("num_hidden_layers", 8)
+         self.max_position_embeddings = kwargs.get("max_position_embeddings", 8192)
+         self.vocab_size = kwargs.get("vocab_size", 6400)
+         self.rms_norm_eps = kwargs.get("rms_norm_eps", 1e-6)
+         self.rope_theta = kwargs.get("rope_theta", 5e4)
+         self.rope_scaling = kwargs.get("rope_scaling", None)
+         self.dtype = kwargs.get("dtype", "float32")
+         self.self_distill = kwargs.get("self_distill", True)
+         self.intermediate_size = kwargs.get("intermediate_size", 1536)
+         self.expert_intermediate_size = kwargs.get("expert_intermediate_size", None) or self.intermediate_size
+         self.n_routed_experts = kwargs.get("n_routed_experts", 0)
+         self.moe_topk = kwargs.get("moe_topk", 2)
+         self.score_func = kwargs.get("score_func", "softmax")
+         self.n_shared_experts = kwargs.get("n_shared_experts", 0)
+         self.top_k_layer_dense = kwargs.get("top_k_layer_dense", 1)
+         self.aux_loss_alpha = kwargs.get("aux_loss_alpha", 0.02)
+         self.seq_aux = kwargs.get("seq_aux", False)
+         self.norm_topk_prob = kwargs.get("norm_topk_prob", True)
+         self.noisy_expert = kwargs.get("noisy_expert", 0.0)
+         self.moe_backend = kwargs.get("moe_backend", "compact")
+         self.router_bias_enabled = kwargs.get("router_bias_enabled", True)
+         self.router_bias_update_rate = kwargs.get("router_bias_update_rate", 1e-3)
+         self.router_bias_clamp = kwargs.get("router_bias_clamp", 5.0)
+         self.num_heads = kwargs.get("num_heads", 12)
+         self.mla_kv_lora_rank = kwargs.get("mla_kv_lora_rank", 64)
+         self.mla_qk_nope_head_dim = kwargs.get("mla_qk_nope_head_dim", 64)
+         self.mla_qk_rope_head_dim = kwargs.get("mla_qk_rope_head_dim", 32)
+         self.mla_attn_impl = kwargs.get("mla_attn_impl", "absorb")
+         self.qkv_lora = kwargs.get("qkv_lora", False)
+
+     @property
+     def head_dim(self) -> int:
+         return self.mla_qk_nope_head_dim + self.mla_qk_rope_head_dim
+
+     def scale_lvl(self, lvl: int = 0):
+         if lvl == 0:
+             self.hidden_size = 1024
+             self.num_hidden_layers = 8
+             self.num_heads = 8
+             self.mla_kv_lora_rank = 256
+             self.mla_qk_nope_head_dim = 192
+             self.mla_qk_rope_head_dim = 64
+             self.intermediate_size = 2048
+             self.expert_intermediate_size = 512
+             self.n_routed_experts = 16
+             self.moe_topk = 1
+             self.n_shared_experts = 0
+             self.top_k_layer_dense = 1
+             self.router_bias_update_rate = 1e-3
+         elif lvl == -1:
+             self.hidden_size = 768
+             self.num_hidden_layers = 8
+             self.num_heads = 6
+             self.mla_kv_lora_rank = 128
+             self.mla_qk_nope_head_dim = 64
+             self.mla_qk_rope_head_dim = 64
+             self.intermediate_size = 1536
+             self.expert_intermediate_size = 768
+             self.n_routed_experts = 0
+             self.moe_topk = 2
+             self.n_shared_experts = 0
+             self.top_k_layer_dense = 8
+         elif lvl == -2:
+             self.hidden_size = 512
+             self.num_hidden_layers = 4
+             self.num_heads = 4
+             self.mla_kv_lora_rank = 128
+             self.mla_qk_nope_head_dim = 64
+             self.mla_qk_rope_head_dim = 32
+             self.intermediate_size = 1024
+             self.expert_intermediate_size = 512
+             self.n_routed_experts = 0
+             self.moe_topk = 2
+             self.n_shared_experts = 0
+             self.top_k_layer_dense = 4
+         else:
+             raise ValueError(f"invalid ymodel3 scale level: {lvl}")
+         return self
+
+     @classmethod
+     def from_json_file(cls, path: str) -> "YConfig3":
+         with open(path, "r", encoding="utf-8") as f:
+             data = json.load(f)
+         return cls(**data)
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "YConfig3":
+         return cls(**data)
+
+
+ # ── Basic modules ──────────────────────────────────────────────────
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
+         return (out * self.weight.float()).to(x.dtype)
+
+
+ class SEBlock(nn.Module):
+     def __init__(self, dim: int, reduction: int = 16, act: Optional[nn.Module] = None):
+         super().__init__()
+         reduction = max(reduction, dim // reduction)
+         self.se = nn.Sequential(
+             nn.Linear(dim, reduction, bias=False),
+             act or nn.SiLU(),
+             nn.Linear(reduction, dim, bias=False),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x * self.se(x)
+
+
+ # ── RoPE helpers ──────────────────────────────────────────────────
+
+
+ def _yarn_linear_ramp(low: float, high: float, dim: int) -> torch.Tensor:
+     if low == high:
+         high += 0.001
+     linear = (torch.arange(dim, dtype=torch.float32) - low) / (high - low)
+     return torch.clamp(linear, 0.0, 1.0)
+
+
+ def _yarn_correction_dim(num_rotations: float, dim: int, theta: float, max_position_embeddings: int) -> float:
+     return dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) / (2 * math.log(theta))
+
+
+ def precompute_freqs_cis(
+     dim: int,
+     end: int,
+     theta: float,
+     rope_scaling: Optional[dict] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+     attention_factor = 1.0
+     if rope_scaling and str(rope_scaling.get("type", "yarn")).lower() == "yarn":
+         factor = float(rope_scaling.get("factor", 1.0))
+         if factor > 1.0:
+             original = int(rope_scaling.get("original_max_position_embeddings", end))
+             beta_fast = float(rope_scaling.get("beta_fast", 32.0))
+             beta_slow = float(rope_scaling.get("beta_slow", 1.0))
+             low = math.floor(_yarn_correction_dim(beta_fast, dim, theta, original))
+             high = math.ceil(_yarn_correction_dim(beta_slow, dim, theta, original))
+             ramp = _yarn_linear_ramp(low, high, dim // 2)
+             freqs = freqs / factor * (1.0 - ramp) + freqs * ramp
+             attention_factor = float(rope_scaling.get("attention_factor", 1.0))
+     t = torch.arange(end)
+     freqs = torch.outer(t, freqs).float()
+     freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1) * attention_factor
+     freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1) * attention_factor
+     return freqs_cos, freqs_sin
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     return torch.cat((-x[..., x.shape[-1] // 2 :], x[..., : x.shape[-1] // 2]), dim=-1)
+
+
+ def apply_rope_to_single(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
+     if cos.dim() == 2:
+         cos = cos.unsqueeze(0).unsqueeze(0)
+         sin = sin.unsqueeze(0).unsqueeze(0)
+     elif cos.dim() == 3:
+         cos = cos.unsqueeze(1)
+         sin = sin.unsqueeze(1)
+     return (x * cos) + (rotate_half(x) * sin)
+
+
+ # ── Attention ──────────────────────────────────────────────────────
+
+
+ class MLGA(nn.Module):
+     """Multihead Latent Gated Attention"""
+
+     def __init__(self, config: YConfig3, layer_id: int):
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_heads
+         self.dropout = config.dropout
+         self.kv_lora_rank = config.mla_kv_lora_rank
+         self.qk_nope_head_dim = config.mla_qk_nope_head_dim
+         self.qk_rope_head_dim = config.mla_qk_rope_head_dim
+         self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+         self.attn_impl = config.mla_attn_impl
+         self.softmax_scale = self.qk_head_dim ** -0.5
+         self.out_dim = self.num_heads * self.kv_lora_rank
+
+         self.wq = nn.Linear(self.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
+         self.wkv_a = nn.Linear(self.hidden_size, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
+         self.kv_norm = RMSNorm(self.kv_lora_rank, config.rms_norm_eps)
+         self.wkv_b = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim, bias=False)
+         self.z_proj = nn.Linear(self.hidden_size, self.out_dim, bias=False)
+         self.o_proj = nn.Linear(self.out_dim, self.hidden_size, bias=False)
+
+     def _project_q(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         bsz, seq_len, _ = x.shape
+         q = self.wq(x)
+         q = q.reshape(bsz, seq_len, self.num_heads, self.qk_head_dim)
+         return q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+     def _project_kv(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         raw = self.wkv_a(x)
+         c_kv, k_pe = raw.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+         c_kv = self.kv_norm(c_kv)
+         k_pe = apply_rope_to_single(k_pe.unsqueeze(1), cos, sin).permute(0, 2, 1, 3)
+         return c_kv, k_pe
+
+     def _explicit_kv(self, c_kv: torch.Tensor, k_pe: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         bsz, seq_len, _ = c_kv.shape
+         k_nope = self.wkv_b(c_kv).reshape(bsz, seq_len, self.num_heads, self.qk_nope_head_dim)
+         k = torch.cat([k_nope, k_pe.expand(-1, -1, self.num_heads, -1)], dim=-1)
+         v = c_kv.unsqueeze(2).expand(-1, -1, self.num_heads, -1)
+         return k, v
+
+     def _attention_mask(self, attention_mask: Optional[torch.Tensor], bsz: int, seq_len: int, total_len: int):
+         if attention_mask is None:
+             return None
+         if attention_mask.shape[-1] != total_len:
+             attention_mask = attention_mask[..., -total_len:]
+         mask = attention_mask.reshape(bsz, 1, 1, total_len).bool()
+         return mask.expand(bsz, self.num_heads, seq_len, total_len)
+
+     def _forward_sdpa(
+         self,
+         q_nope: torch.Tensor,
+         q_pe: torch.Tensor,
+         c_kv: torch.Tensor,
+         k_pe: torch.Tensor,
+         z: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         bsz, seq_len, _, _ = q_nope.shape
+         total_len = c_kv.shape[1]
+         k, v = self._explicit_kv(c_kv, k_pe)
+         q = torch.cat([q_nope, q_pe], dim=-1).permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+         v = v.permute(0, 2, 1, 3)
+         attn_mask = self._attention_mask(attention_mask, bsz, seq_len, total_len)
+         is_causal = attention_mask is None and seq_len == total_len
+         out = F.scaled_dot_product_attention(
+             q, k, v,
+             attn_mask=attn_mask,
+             dropout_p=self.dropout if self.training else 0.0,
+             is_causal=is_causal,
+             scale=self.softmax_scale,
+         )
+         out = out.permute(0, 2, 1, 3).reshape(bsz, seq_len, self.out_dim)
+         out = out * torch.sigmoid(z)
+         return self.o_proj(out)
+
+     def _forward_absorb(
+         self,
+         q_nope: torch.Tensor,
+         q_pe: torch.Tensor,
+         c_kv: torch.Tensor,
+         k_pe: torch.Tensor,
+         z: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         bsz, seq_len, _, _ = q_nope.shape
+         total_len = c_kv.shape[1]
+         w = self.wkv_b.weight.reshape(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+         q_nope_c = torch.einsum("bshd,hdc->bshc", q_nope, w)
+         scores = torch.einsum("bshc,btc->bsht", q_nope_c, c_kv)
+         scores = scores + torch.einsum("bshr,btr->bsht", q_pe, k_pe.squeeze(2))
+         scores = scores * self.softmax_scale
+
+         causal = torch.full((seq_len, seq_len), float("-inf"), device=scores.device, dtype=scores.dtype)
+         causal = torch.triu(causal, diagonal=1).reshape(1, seq_len, 1, seq_len)
+         scores = scores + F.pad(causal, (total_len - seq_len, 0), value=0.0)
+         if attention_mask is not None:
+             if attention_mask.shape[-1] != total_len:
+                 attention_mask = attention_mask[..., -total_len:]
+             scores = scores + (1.0 - attention_mask.reshape(bsz, 1, 1, total_len).float()) * -1e9
+         probs = torch.softmax(scores.float(), dim=-1).to(q_nope.dtype)
+         out = torch.einsum("bsht,btc->bshc", probs, c_kv).reshape(bsz, seq_len, self.out_dim)
+         out = out * torch.sigmoid(z)
+         return self.o_proj(out)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         past_key_values: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         use_cache: bool = False,
+         **kwargs,
+     ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
+         bsz, seq_len, _ = x.shape
+         cos, sin = position_embeddings
+         if cos.dim() == 2:
+             cos = cos[:seq_len, : self.qk_rope_head_dim]
+             sin = sin[:seq_len, : self.qk_rope_head_dim]
+         else:
+             cos = cos[:, :seq_len, : self.qk_rope_head_dim]
+             sin = sin[:, :seq_len, : self.qk_rope_head_dim]
+         q_nope, q_pe = self._project_q(x)
+         q_pe = apply_rope_to_single(q_pe.permute(0, 2, 1, 3), cos, sin).permute(0, 2, 1, 3)
+         c_kv, k_pe = self._project_kv(x, cos, sin)
+         z = self.z_proj(x)
+
+         if past_key_values is not None:
+             past_c, past_pe = past_key_values
+             c_kv = torch.cat([past_c, c_kv], dim=1)
+             k_pe = torch.cat([past_pe, k_pe], dim=1)
+         new_past = (c_kv, k_pe) if use_cache else None
+
+         if self.attn_impl == "naive":
+             out = self._forward_sdpa(q_nope, q_pe, c_kv, k_pe, z, attention_mask)
+         else:
+             out = self._forward_absorb(q_nope, q_pe, c_kv, k_pe, z, attention_mask)
+         out = F.dropout(out, p=self.dropout, training=self.training)
+         return out, new_past
+
+
+ # ── FFN / MoE ──────────────────────────────────────────────────────
+
+
+ _ACT_FNS = {
+     "silu": F.silu,
+     "swish": F.silu,
+     "relu": F.relu,
+     "gelu": lambda x: F.gelu(x, approximate="tanh"),
+     "sigmoid": torch.sigmoid,
+ }
+
+ _ACT_MODULES = {
+     "silu": nn.SiLU,
+     "swish": nn.SiLU,
+     "relu": nn.ReLU,
+     "gelu": lambda: nn.GELU(approximate="tanh"),
+     "sigmoid": nn.Sigmoid,
+ }
+
+
+ class DenseFFN(nn.Module):
+     def __init__(self, config: YConfig3, intermediate_size: Optional[int] = None):
+         super().__init__()
+         inter = intermediate_size or config.intermediate_size
+         self.up_proj = nn.Linear(config.hidden_size, inter, bias=False)
+         self.gate_proj = nn.Linear(config.hidden_size, inter, bias=False)
+         self.down_proj = nn.Linear(inter, config.hidden_size, bias=False)
+         self.hidden_act = config.hidden_act
+         self.act = _ACT_FNS.get(config.hidden_act, F.silu)
+         self.dropout = config.dropout
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         up, gate = self.up_proj(x), self.gate_proj(x)
+         up = self.act(gate) * up
+         up = F.dropout(up, p=self.dropout, training=self.training)
+         return self.down_proj(up)
+
+
+ class MoEGate(nn.Module):
+     def __init__(self, config: YConfig3):
+         super().__init__()
+         self.n_routed_experts = config.n_routed_experts
+         self.topk = min(config.moe_topk, max(1, config.n_routed_experts))
+         self.score_func = config.score_func
+         self.norm_topk_prob = config.norm_topk_prob
+         self.aux_loss_alpha = config.aux_loss_alpha
+         self.seq_aux = config.seq_aux
+         self.router_bias_enabled = config.router_bias_enabled
+         self.router_bias_update_rate = config.router_bias_update_rate
+         self.router_bias_clamp = config.router_bias_clamp
+         self.weight = nn.Linear(int(config.hidden_size), int(self.n_routed_experts), bias=False)
+         if self.router_bias_enabled:
+             self.register_buffer("router_bias", torch.zeros(self.n_routed_experts), persistent=True)
+         else:
+             self.register_buffer("router_bias", None, persistent=False)
+
+     def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None):
+         bsz, seq_len, hidden = x.shape
+         flat = x.reshape(-1, hidden)
+         route_logits = self.weight(flat)
+         if self.score_func == "softmax":
+             route_scores = torch.softmax(route_logits.float(), dim=-1).to(x.dtype)
+         elif self.score_func == "sigmoid":
+             route_scores = torch.sigmoid(route_logits.float()).to(x.dtype)
+         else:
+             raise ValueError(f"unsupported MoE score_func: {self.score_func}")
+
+         choice_scores = route_scores
+         if self.router_bias is not None:
+             choice_scores = choice_scores + self.router_bias.to(dtype=choice_scores.dtype).unsqueeze(0)
+
+         topk_idx = torch.topk(choice_scores, k=self.topk, dim=-1, sorted=False).indices
+         topk_weight = route_scores.gather(1, topk_idx)
+         if self.topk > 1 and self.norm_topk_prob:
+             denom = topk_weight.float().sum(dim=-1, keepdim=True) + 1e-20
+             topk_weight = (topk_weight.float() / denom).to(x.dtype)
+
+         aux_loss = x.new_zeros((), dtype=x.dtype)
+         return (
+             topk_idx.reshape(bsz, seq_len, self.topk),
+             topk_weight.reshape(bsz, seq_len, self.topk),
+             aux_loss,
+         )
+
+
+ def _torch_moe_swiglu(
+     x: torch.Tensor,
+     topk_idx: torch.Tensor,
+     topk_weight: torch.Tensor,
+     w_up: torch.Tensor,
+     w_down: torch.Tensor,
+     activation: str = "silu",
+ ) -> torch.Tensor:
+     """Pure PyTorch MoE SwiGLU forward (inference only, no noisy_expert)."""
+     original_shape = x.shape
+     x_flat = x.reshape(-1, x.shape[-1])
+     idx = topk_idx.reshape(x_flat.shape[0], -1)
+     weight = topk_weight.reshape(x_flat.shape[0], -1)
+     y = torch.zeros_like(x_flat)
+     n_experts = w_up.shape[0]
+     inter = w_down.shape[-1]
+     act_fn = _ACT_FNS.get(activation, F.silu)
+     for expert_id in range(n_experts):
+         token_pos, choice_pos = torch.where(idx == expert_id)
+         if token_pos.numel() == 0:
+             continue
+         inp = x_flat[token_pos]
+         uv = F.linear(inp, w_up[expert_id])
+         up, gate = uv.split(inter, dim=-1)
+         hidden = act_fn(gate) * up
+         out = F.linear(hidden, w_down[expert_id])
+         route_w = weight[token_pos, choice_pos].unsqueeze(-1)
+         y.index_add_(0, token_pos, out * route_w)
+     return y.reshape(original_shape)
+
+
+ class YMoE(nn.Module):
+     """Pure PyTorch eval MoE (no Triton dependency)."""
+
+     def __init__(self, config: YConfig3, layer_id: int):
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = config.hidden_size
+         self.expert_intermediate_size = config.expert_intermediate_size
+         self.intermediate_size = self.expert_intermediate_size
+         self.n_routed_experts = config.n_routed_experts
+         self.use_moe = self.n_routed_experts > 0 and layer_id >= config.top_k_layer_dense
+         self.noisy_expert = config.noisy_expert
+         if not self.use_moe:
+             self.dense = DenseFFN(config)
+             self.gate = None
+             self.w_up = None
+             self.w_down = None
+             return
+         self.dense = None
+         self.gate = MoEGate(config)
+         self.w_up = nn.Parameter(torch.empty(self.n_routed_experts, 2 * self.expert_intermediate_size, self.hidden_size))
+         self.w_down = nn.Parameter(torch.empty(self.n_routed_experts, self.hidden_size, self.expert_intermediate_size))
+         nn.init.kaiming_uniform_(self.w_up, a=math.sqrt(5))
+         nn.init.kaiming_uniform_(self.w_down, a=math.sqrt(5))
+
+     def forward(self, x: torch.Tensor, aux_mask: Optional[torch.Tensor] = None):
+         if not self.use_moe:
+             return self.dense(x), None
+         topk_idx, topk_weight, aux_loss = self.gate(x, aux_mask)
+         y = _torch_moe_swiglu(x, topk_idx, topk_weight, self.w_up, self.w_down, activation="silu")
+         return y, aux_loss
+
+
+ # ── Transformer block ──────────────────────────────────────────────
+
+
+ class YBlock3(nn.Module):
+     def __init__(self, config: YConfig3, layer_id: int):
+         super().__init__()
+         self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+         self.attn = MLGA(config, layer_id)
+         self.ffn = YMoE(config, layer_id)
+         act_module = _ACT_MODULES.get(config.hidden_act, nn.SiLU)
+         self.se1 = SEBlock(config.hidden_size, act=act_module() if isinstance(act_module, type) else act_module())
+         self.se2 = SEBlock(config.hidden_size, act=nn.SiLU())
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         past_key_values=None,
+         use_cache: bool = False,
+         attention_mask: Optional[torch.Tensor] = None,
+         aux_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ):
+         x0 = self.se1(self.input_layernorm(x))
+         attn_out, past = self.attn(
+             x0,
+             position_embeddings,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             use_cache=use_cache,
+         )
+         x = x + attn_out
+         x0 = self.se2(self.post_attention_layernorm(x))
+         ffn_out, aux_loss = self.ffn(x0, aux_mask)
+         x = x + ffn_out
+         return x, past, aux_loss
+
+
+ # ── Full model ────────────────────────────────────────────────────
+
+
+ class YModel3(nn.Module):
+     def __init__(self, config: YConfig3):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_layers = config.num_hidden_layers
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.dropout = config.dropout
+         self.use_self_distill = config.self_distill
+         self.layers = nn.ModuleList([YBlock3(config, i) for i in range(config.num_hidden_layers)])
+         self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+         freqs_cos, freqs_sin = precompute_freqs_cis(
+             dim=config.mla_qk_rope_head_dim,
+             end=config.max_position_embeddings,
+             theta=config.rope_theta,
+             rope_scaling=config.rope_scaling,
+         )
+         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[list] = None,
+         use_cache: bool = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ):
+         bsz, seq_len = input_ids.shape
+         if use_cache and past_key_values is None:
+             past_key_values = [None] * self.num_layers
+         if cache_position is None:
+             if past_key_values is not None and past_key_values[0] is not None:
+                 past_seen = past_key_values[0][0].shape[1]
+             else:
+                 past_seen = 0
+             cache_position = torch.arange(past_seen, past_seen + seq_len, device=input_ids.device)
+
+         x = F.dropout(self.embed_tokens(input_ids), p=self.dropout, training=self.training)
+         if position_ids is None:
+             position_ids = cache_position
+         position_embeddings = (self.freqs_cos[position_ids].to(x.device), self.freqs_sin[position_ids].to(x.device))
+         aux_mask = None
+         new_past = [] if use_cache else None
+         aux_loss = None
+
+         for i, layer in enumerate(self.layers):
+             past = past_key_values[i] if past_key_values is not None else None
+             x, layer_past, layer_aux = layer(
+                 x,
+                 position_embeddings=position_embeddings,
+                 past_key_values=past,
+                 attention_mask=attention_mask,
+                 use_cache=use_cache,
+                 aux_mask=aux_mask,
+             )
+             if use_cache:
+                 new_past.append(layer_past)
+             if self.training and layer_aux is not None:
+                 aux_loss = layer_aux if aux_loss is None else aux_loss + layer_aux
+
+         return self.norm(x), new_past, None, aux_loss
+
+
+ class _InferenceOutput:
+     """Simple container for model outputs (replaces transformers CausalLMOutputWithPast)."""
+
+     __slots__ = ("last_hidden_state", "logits", "past_key_values", "dist_loss", "aux_loss")
+
+     def __init__(self):
+         self.last_hidden_state = None
+         self.logits = None
+         self.past_key_values = None
+         self.dist_loss = None
+         self.aux_loss = None
+
+     def __setitem__(self, key, value):
+         setattr(self, key, value)
+
+
+ class YForCausalLM3(nn.Module):
+     """Pure PyTorch CausalLM wrapper for ymodel3 inference (no transformers dependency)."""
+
+     config_class = YConfig3
+
+     def __init__(self, config: Optional[YConfig3] = None):
+         super().__init__()
+         self.config = config or YConfig3()
+         self.model = YModel3(self.config)
+         self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
+         self.model.embed_tokens.weight = self.lm_head.weight
+         self.OUT = _InferenceOutput()
+         dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}.get(self.config.dtype)
+         if dtype is not None:
+             self.to(dtype)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[list] = None,
+         use_cache: bool = False,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ):
+         h, past_kvs, dist_loss, aux_loss = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             cache_position=cache_position,
+             position_ids=kwargs.get("position_ids", None),
+         )
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(h[:, slice_indices, :])
+         self.OUT.__setitem__("last_hidden_state", h)
+         self.OUT.__setitem__("logits", logits)
+         self.OUT.__setitem__("past_key_values", past_kvs)
+         self.OUT.__setitem__("dist_loss", dist_loss)
+         self.OUT.__setitem__("aux_loss", aux_loss)
+         return self.OUT
+
+     def generate(
+         self,
+         inputs,
+         attention_mask=None,
+         max_new_tokens=8192,
+         temperature=0.85,
+         top_p=0.85,
+         top_k=50,
+         eos_token_id=None,
+         streamer=None,
+         use_cache=True,
+         num_return_sequences=1,
+         do_sample=True,
+         repetition_penalty=1.0,
+         **kwargs,
+     ):
+         input_ids = kwargs.get("input_ids", inputs).repeat(num_return_sequences, 1)
+         attention_mask = attention_mask.repeat(num_return_sequences, 1) if attention_mask is not None else None
+         past_key_values = None
+         if streamer:
+             streamer.put(input_ids.cpu())
+         with torch.no_grad():
+             for _ in range(max_new_tokens):
+                 if use_cache and past_key_values is not None:
+                     outputs = self.forward(input_ids[:, -1:], None, past_key_values, use_cache=use_cache)
+                 else:
+                     outputs = self.forward(input_ids, attention_mask, past_key_values, use_cache=use_cache)
+                 logits = outputs.logits[:, -1, :] / temperature
+                 if repetition_penalty != 1.0:
+                     for i in range(input_ids.shape[0]):
+                         logits[i, torch.unique(input_ids[i])] /= repetition_penalty
+                 if top_k > 0:
+                     logits[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float("inf")
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                     mask = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) > top_p
+                     mask[..., 1:], mask[..., 0] = mask[..., :-1].clone(), 0
+                     logits[mask.scatter(1, sorted_indices, mask)] = -float("inf")
+                 next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1) if do_sample else torch.argmax(logits, dim=-1, keepdim=True)
+                 input_ids = torch.cat([input_ids, next_token], dim=-1)
+                 past_key_values = outputs.past_key_values if use_cache else None
+                 if streamer:
+                     streamer.put(next_token.cpu())
+                 if eos_token_id and (next_token == eos_token_id).any():
+                     break
+         if streamer:
+             streamer.end()
+         return input_ids
+
+
+ # ── Loading utilities ──────────────────────────────────────────────
+
+
+ def _load_state_dict(path: Union[str, Path]) -> dict[str, torch.Tensor]:
+     path = Path(path)
+     if path.is_dir():
+         safetensors_path = path / "model.safetensors"
+         bin_path = path / "pytorch_model.bin"
+         if safetensors_path.exists():
+             path = safetensors_path
+         elif bin_path.exists():
+             path = bin_path
+         else:
+             raise FileNotFoundError(f"no model.safetensors or pytorch_model.bin found in {path}")
+     if path.suffix == ".safetensors":
+         return load_safetensors(str(path), device="cpu")
+     return torch.load(path, map_location="cpu", weights_only=True)
+
+
+ def load_ymodel3_eval(path: Union[str, Path], config: Optional[YConfig3] = None, strict: bool = True) -> YForCausalLM3:
+     if config is None:
+         config_path = Path(path) / "config.json" if Path(path).is_dir() else Path(path).with_name("config.json")
+         if not config_path.exists():
+             raise FileNotFoundError("config is required when config.json is not next to the checkpoint")
+         config = YConfig3.from_json_file(str(config_path))
+     model = YForCausalLM3(config)
+     state = _load_state_dict(path)
+     model.load_state_dict(state, strict=strict)
+     model.eval()
+     return model
+
+
+ # ── Backward-compatible aliases ────────────────────────────────────
+
+ YModel3Eval = YModel3
+ YForCausalLM3Eval = YForCausalLM3
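
Usage note (not part of the commit): a minimal end-to-end sketch combining load_ymodel3_eval with a transformers tokenizer for the chat template. Paths and sampling values are illustrative:

    from transformers import AutoTokenizer
    from ymodel3_eval import load_ymodel3_eval

    ckpt_dir = "."  # directory holding config.json and model.safetensors
    model = load_ymodel3_eval(ckpt_dir)
    tok = AutoTokenizer.from_pretrained(ckpt_dir)

    messages = [{"role": "user", "content": "Hello!"}]
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    out = model.generate(
        input_ids,
        max_new_tokens=64,
        temperature=0.85,
        top_p=0.85,
        eos_token_id=tok.eos_token_id,
    )
    print(tok.decode(out[0, input_ids.shape[1]:], skip_special_tokens=True))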