Spico commited on
Commit
1b4cc3c
·
verified ·
1 Parent(s): 259deaa

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - MoE
7
+ ---
8
+ # LLaMA-MoE-v1-3.5B (4/16) SFT
9
+
10
+ [[💻 Code]](https://github.com/pjlab-sys4nlp/llama-moe) | [[📜 Technical Report]](https://github.com/pjlab-sys4nlp/llama-moe/blob/main/docs/LLaMA_MoE.pdf)
11
+
12
+ This is the supervised fine-tuned version of [LLaMA-MoE-v1-3_5B-4_16](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) on [Deita-6k](https://huggingface.co/datasets/hkust-nlp/deita-6k-v0) for 2 epochs.
13
+
14
+
15
+ | Model | \#Activated Experts | \#Experts | \#Activated Params | Foundation Model | SFT Model |
16
+ | :------------------------ | :-----------------: | :-------: | :----------------: | :---------------------------------------------------------------: | :------------------------------------------------------------------: |
17
+ | **LLaMA-MoE-3.0B** | 2 | 16 | 3.0B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_0B-2_16-sft) |
18
+ | **LLaMA-MoE-3.5B (4/16)** | 4 | 16 | 3.5B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-4_16-sft) |
19
+ | **LLaMA-MoE-3.5B (2/8)** | 2 | 8 | 3.5B | [🤗 base](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8) | [🤗 SFT](https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8-sft) |
20
+
21
+
22
+ ## 🚀 QuickStart
23
+
24
+ ```python
25
+ # python>=3.10
26
+
27
+ import torch
28
+ from transformers import AutoTokenizer, AutoModelForCausalLM
29
+
30
+ model_dir = "llama-moe/LLaMA-MoE-v1-3_5B-4_16-sft"
31
+ tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
32
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, trust_remote_code=True)
33
+ model.eval()
34
+ model.cuda()
35
+
36
+ input_text = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. human: Give me a three-day plan in Suzhou. gpt:"
37
+ inputs = tokenizer(input_text, return_tensors="pt")
38
+ input_ids = inputs["input_ids"].cuda()
39
+
40
+ pred = model.generate(input_ids, max_length=100, temperature=1.0, do_sample=True, use_cache=True)
41
+ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
42
+ """
43
+ Sure, I can provide you with a three-day itinerary in Suzhou. Here's what we can do:
44
+
45
+ Day 1:
46
+
47
+ * Visit Suzhou Industrial Park, a major commercial and manufacturing district ...
48
+ """
49
+ ```
50
+
51
+ ## 📊 Performance
52
+
53
+ | Model | MMLU | ARC-c | HellaSeag | TruthfulQA | MT-Bench |
54
+ | :------------------------------------- | :---: | :---: | :-------: | :--------: | :------: |
55
+ | Sheared LLaMA-2.7B ShareGPT | 28.41 | 41.04 | 71.21 | 47.65 | 3.79 |
56
+ | Sheared LLaMA-2.7B Deita6K (Our Impl.) | 25.24 | 43.69 | 71.70 | 49.00 | 4.06 |
57
+ | LLaMA-MoE-v1-3.0B (2/16) | 23.61 | 43.43 | 72.28 | 44.24 | 4.15 |
58
+ | LLaMA-MoE-v1-3.5B (4/16) | 26.49 | 48.29 | 75.10 | 45.91 | 4.60 |
59
+ | LLaMA-MoE-v1-3.5B (2/8) | 25.53 | 45.99 | 74.95 | 44.39 | 4.72 |
60
+
61
+ ## 📃 Citation
62
+
63
+ ```bibtex
64
+ @misc{llama-moe2023,
65
+ title={LLaMA-MoE: Building Mixture-of-Experts from LLaMA with Continual Pre-training},
66
+ author={LLaMA-MoE Team},
67
+ year={2023},
68
+ publisher={Dec}
69
+ }
70
+ ```
config.json ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-4_16-new",
3
+ "add_weight_norm": false,
4
+ "architectures": [
5
+ "LlamaMoEForCausalLM"
6
+ ],
7
+ "attention_bias": false,
8
+ "attention_dropout": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_llama_moe.LlamaMoEConfig",
11
+ "AutoModel": "modeling_llama_moe_hf.LlamaMoEModel",
12
+ "AutoModelForCausalLM": "modeling_llama_moe_hf.LlamaMoEForCausalLM"
13
+ },
14
+ "bos_token_id": 1,
15
+ "calculator_type": "UniversalCalculator",
16
+ "capacity_factor": 1.25,
17
+ "drop_tokens": true,
18
+ "dropped_padding": "zero",
19
+ "eos_token_id": 2,
20
+ "gate_add_noise": true,
21
+ "gate_balance_loss_weight": 0.01,
22
+ "gate_network": "mlp",
23
+ "gate_noise_epsilon": 0.01,
24
+ "gate_type": "TopKBalancedNoisyGate",
25
+ "gate_use_balance": true,
26
+ "gate_use_softmax": true,
27
+ "gates": "mlp",
28
+ "hidden_act": "silu",
29
+ "hidden_size": 4096,
30
+ "initializer_range": 0.02,
31
+ "intermediate_size": 11008,
32
+ "max_position_embeddings": 4096,
33
+ "model_type": "llama_moe",
34
+ "multiply_gate_scores": true,
35
+ "num_attention_heads": 32,
36
+ "num_experts": 16,
37
+ "num_hidden_layers": 32,
38
+ "num_key_value_heads": 32,
39
+ "num_selects": 4,
40
+ "pad_token_id": 0,
41
+ "pretraining_tp": 1,
42
+ "rms_norm_eps": 1e-05,
43
+ "rope_scaling": null,
44
+ "rope_theta": 10000.0,
45
+ "score_scale_factor": 4.0,
46
+ "size_experts": [
47
+ [
48
+ 688,
49
+ 688,
50
+ 688,
51
+ 688,
52
+ 688,
53
+ 688,
54
+ 688,
55
+ 688,
56
+ 688,
57
+ 688,
58
+ 688,
59
+ 688,
60
+ 688,
61
+ 688,
62
+ 688,
63
+ 688
64
+ ],
65
+ [
66
+ 688,
67
+ 688,
68
+ 688,
69
+ 688,
70
+ 688,
71
+ 688,
72
+ 688,
73
+ 688,
74
+ 688,
75
+ 688,
76
+ 688,
77
+ 688,
78
+ 688,
79
+ 688,
80
+ 688,
81
+ 688
82
+ ],
83
+ [
84
+ 688,
85
+ 688,
86
+ 688,
87
+ 688,
88
+ 688,
89
+ 688,
90
+ 688,
91
+ 688,
92
+ 688,
93
+ 688,
94
+ 688,
95
+ 688,
96
+ 688,
97
+ 688,
98
+ 688,
99
+ 688
100
+ ],
101
+ [
102
+ 688,
103
+ 688,
104
+ 688,
105
+ 688,
106
+ 688,
107
+ 688,
108
+ 688,
109
+ 688,
110
+ 688,
111
+ 688,
112
+ 688,
113
+ 688,
114
+ 688,
115
+ 688,
116
+ 688,
117
+ 688
118
+ ],
119
+ [
120
+ 688,
121
+ 688,
122
+ 688,
123
+ 688,
124
+ 688,
125
+ 688,
126
+ 688,
127
+ 688,
128
+ 688,
129
+ 688,
130
+ 688,
131
+ 688,
132
+ 688,
133
+ 688,
134
+ 688,
135
+ 688
136
+ ],
137
+ [
138
+ 688,
139
+ 688,
140
+ 688,
141
+ 688,
142
+ 688,
143
+ 688,
144
+ 688,
145
+ 688,
146
+ 688,
147
+ 688,
148
+ 688,
149
+ 688,
150
+ 688,
151
+ 688,
152
+ 688,
153
+ 688
154
+ ],
155
+ [
156
+ 688,
157
+ 688,
158
+ 688,
159
+ 688,
160
+ 688,
161
+ 688,
162
+ 688,
163
+ 688,
164
+ 688,
165
+ 688,
166
+ 688,
167
+ 688,
168
+ 688,
169
+ 688,
170
+ 688,
171
+ 688
172
+ ],
173
+ [
174
+ 688,
175
+ 688,
176
+ 688,
177
+ 688,
178
+ 688,
179
+ 688,
180
+ 688,
181
+ 688,
182
+ 688,
183
+ 688,
184
+ 688,
185
+ 688,
186
+ 688,
187
+ 688,
188
+ 688,
189
+ 688
190
+ ],
191
+ [
192
+ 688,
193
+ 688,
194
+ 688,
195
+ 688,
196
+ 688,
197
+ 688,
198
+ 688,
199
+ 688,
200
+ 688,
201
+ 688,
202
+ 688,
203
+ 688,
204
+ 688,
205
+ 688,
206
+ 688,
207
+ 688
208
+ ],
209
+ [
210
+ 688,
211
+ 688,
212
+ 688,
213
+ 688,
214
+ 688,
215
+ 688,
216
+ 688,
217
+ 688,
218
+ 688,
219
+ 688,
220
+ 688,
221
+ 688,
222
+ 688,
223
+ 688,
224
+ 688,
225
+ 688
226
+ ],
227
+ [
228
+ 688,
229
+ 688,
230
+ 688,
231
+ 688,
232
+ 688,
233
+ 688,
234
+ 688,
235
+ 688,
236
+ 688,
237
+ 688,
238
+ 688,
239
+ 688,
240
+ 688,
241
+ 688,
242
+ 688,
243
+ 688
244
+ ],
245
+ [
246
+ 688,
247
+ 688,
248
+ 688,
249
+ 688,
250
+ 688,
251
+ 688,
252
+ 688,
253
+ 688,
254
+ 688,
255
+ 688,
256
+ 688,
257
+ 688,
258
+ 688,
259
+ 688,
260
+ 688,
261
+ 688
262
+ ],
263
+ [
264
+ 688,
265
+ 688,
266
+ 688,
267
+ 688,
268
+ 688,
269
+ 688,
270
+ 688,
271
+ 688,
272
+ 688,
273
+ 688,
274
+ 688,
275
+ 688,
276
+ 688,
277
+ 688,
278
+ 688,
279
+ 688
280
+ ],
281
+ [
282
+ 688,
283
+ 688,
284
+ 688,
285
+ 688,
286
+ 688,
287
+ 688,
288
+ 688,
289
+ 688,
290
+ 688,
291
+ 688,
292
+ 688,
293
+ 688,
294
+ 688,
295
+ 688,
296
+ 688,
297
+ 688
298
+ ],
299
+ [
300
+ 688,
301
+ 688,
302
+ 688,
303
+ 688,
304
+ 688,
305
+ 688,
306
+ 688,
307
+ 688,
308
+ 688,
309
+ 688,
310
+ 688,
311
+ 688,
312
+ 688,
313
+ 688,
314
+ 688,
315
+ 688
316
+ ],
317
+ [
318
+ 688,
319
+ 688,
320
+ 688,
321
+ 688,
322
+ 688,
323
+ 688,
324
+ 688,
325
+ 688,
326
+ 688,
327
+ 688,
328
+ 688,
329
+ 688,
330
+ 688,
331
+ 688,
332
+ 688,
333
+ 688
334
+ ],
335
+ [
336
+ 688,
337
+ 688,
338
+ 688,
339
+ 688,
340
+ 688,
341
+ 688,
342
+ 688,
343
+ 688,
344
+ 688,
345
+ 688,
346
+ 688,
347
+ 688,
348
+ 688,
349
+ 688,
350
+ 688,
351
+ 688
352
+ ],
353
+ [
354
+ 688,
355
+ 688,
356
+ 688,
357
+ 688,
358
+ 688,
359
+ 688,
360
+ 688,
361
+ 688,
362
+ 688,
363
+ 688,
364
+ 688,
365
+ 688,
366
+ 688,
367
+ 688,
368
+ 688,
369
+ 688
370
+ ],
371
+ [
372
+ 688,
373
+ 688,
374
+ 688,
375
+ 688,
376
+ 688,
377
+ 688,
378
+ 688,
379
+ 688,
380
+ 688,
381
+ 688,
382
+ 688,
383
+ 688,
384
+ 688,
385
+ 688,
386
+ 688,
387
+ 688
388
+ ],
389
+ [
390
+ 688,
391
+ 688,
392
+ 688,
393
+ 688,
394
+ 688,
395
+ 688,
396
+ 688,
397
+ 688,
398
+ 688,
399
+ 688,
400
+ 688,
401
+ 688,
402
+ 688,
403
+ 688,
404
+ 688,
405
+ 688
406
+ ],
407
+ [
408
+ 688,
409
+ 688,
410
+ 688,
411
+ 688,
412
+ 688,
413
+ 688,
414
+ 688,
415
+ 688,
416
+ 688,
417
+ 688,
418
+ 688,
419
+ 688,
420
+ 688,
421
+ 688,
422
+ 688,
423
+ 688
424
+ ],
425
+ [
426
+ 688,
427
+ 688,
428
+ 688,
429
+ 688,
430
+ 688,
431
+ 688,
432
+ 688,
433
+ 688,
434
+ 688,
435
+ 688,
436
+ 688,
437
+ 688,
438
+ 688,
439
+ 688,
440
+ 688,
441
+ 688
442
+ ],
443
+ [
444
+ 688,
445
+ 688,
446
+ 688,
447
+ 688,
448
+ 688,
449
+ 688,
450
+ 688,
451
+ 688,
452
+ 688,
453
+ 688,
454
+ 688,
455
+ 688,
456
+ 688,
457
+ 688,
458
+ 688,
459
+ 688
460
+ ],
461
+ [
462
+ 688,
463
+ 688,
464
+ 688,
465
+ 688,
466
+ 688,
467
+ 688,
468
+ 688,
469
+ 688,
470
+ 688,
471
+ 688,
472
+ 688,
473
+ 688,
474
+ 688,
475
+ 688,
476
+ 688,
477
+ 688
478
+ ],
479
+ [
480
+ 688,
481
+ 688,
482
+ 688,
483
+ 688,
484
+ 688,
485
+ 688,
486
+ 688,
487
+ 688,
488
+ 688,
489
+ 688,
490
+ 688,
491
+ 688,
492
+ 688,
493
+ 688,
494
+ 688,
495
+ 688
496
+ ],
497
+ [
498
+ 688,
499
+ 688,
500
+ 688,
501
+ 688,
502
+ 688,
503
+ 688,
504
+ 688,
505
+ 688,
506
+ 688,
507
+ 688,
508
+ 688,
509
+ 688,
510
+ 688,
511
+ 688,
512
+ 688,
513
+ 688
514
+ ],
515
+ [
516
+ 688,
517
+ 688,
518
+ 688,
519
+ 688,
520
+ 688,
521
+ 688,
522
+ 688,
523
+ 688,
524
+ 688,
525
+ 688,
526
+ 688,
527
+ 688,
528
+ 688,
529
+ 688,
530
+ 688,
531
+ 688
532
+ ],
533
+ [
534
+ 688,
535
+ 688,
536
+ 688,
537
+ 688,
538
+ 688,
539
+ 688,
540
+ 688,
541
+ 688,
542
+ 688,
543
+ 688,
544
+ 688,
545
+ 688,
546
+ 688,
547
+ 688,
548
+ 688,
549
+ 688
550
+ ],
551
+ [
552
+ 688,
553
+ 688,
554
+ 688,
555
+ 688,
556
+ 688,
557
+ 688,
558
+ 688,
559
+ 688,
560
+ 688,
561
+ 688,
562
+ 688,
563
+ 688,
564
+ 688,
565
+ 688,
566
+ 688,
567
+ 688
568
+ ],
569
+ [
570
+ 688,
571
+ 688,
572
+ 688,
573
+ 688,
574
+ 688,
575
+ 688,
576
+ 688,
577
+ 688,
578
+ 688,
579
+ 688,
580
+ 688,
581
+ 688,
582
+ 688,
583
+ 688,
584
+ 688,
585
+ 688
586
+ ],
587
+ [
588
+ 688,
589
+ 688,
590
+ 688,
591
+ 688,
592
+ 688,
593
+ 688,
594
+ 688,
595
+ 688,
596
+ 688,
597
+ 688,
598
+ 688,
599
+ 688,
600
+ 688,
601
+ 688,
602
+ 688,
603
+ 688
604
+ ],
605
+ [
606
+ 688,
607
+ 688,
608
+ 688,
609
+ 688,
610
+ 688,
611
+ 688,
612
+ 688,
613
+ 688,
614
+ 688,
615
+ 688,
616
+ 688,
617
+ 688,
618
+ 688,
619
+ 688,
620
+ 688,
621
+ 688
622
+ ]
623
+ ],
624
+ "tie_word_embeddings": false,
625
+ "torch_dtype": "bfloat16",
626
+ "transformers_version": "4.36.2",
627
+ "use_cache": true,
628
+ "vocab_size": 32000
629
+ }
configuration_llama_moe.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class LlamaMoEConfig(PretrainedConfig):
5
+ model_type = "llama_moe"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ intermediate_size=11008,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=None,
16
+ hidden_act="silu",
17
+ max_position_embeddings=2048,
18
+ initializer_range=0.02,
19
+ rms_norm_eps=1e-6,
20
+ use_cache=True,
21
+ pad_token_id=0,
22
+ bos_token_id=1,
23
+ eos_token_id=2,
24
+ pretraining_tp=1,
25
+ tie_word_embeddings=False,
26
+ rope_theta=10000.0,
27
+ rope_scaling=None,
28
+ attention_bias=False,
29
+ attention_dropout=0.0,
30
+ # -------- moe expert configs --------
31
+ num_experts=16,
32
+ num_selects=4,
33
+ size_experts=None,
34
+ # -------- moe gate configs --------
35
+ gate_type="TopKBalancedNoisyGate",
36
+ gate_network="mlp",
37
+ gate_use_softmax=True,
38
+ gate_use_balance=True,
39
+ gate_balance_loss_weight=1e-2,
40
+ gate_add_noise=True,
41
+ # TopKBalancedNoisyGate
42
+ gate_noise_epsilon=1e-2,
43
+ # -------- moe calculator configs --------
44
+ calculator_type="UniversalCalculator",
45
+ multiply_gate_scores=True,
46
+ score_scale_factor=1.0,
47
+ add_weight_norm=False,
48
+ # SwitchDropTokenCalculator
49
+ drop_tokens=True,
50
+ dropped_padding="zero",
51
+ capacity_factor=1.25,
52
+ **kwargs,
53
+ ):
54
+ self.vocab_size = vocab_size
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.hidden_size = hidden_size
57
+ self.intermediate_size = intermediate_size
58
+ self.num_hidden_layers = num_hidden_layers
59
+ self.num_attention_heads = num_attention_heads
60
+ self.hidden_act = hidden_act
61
+ self.initializer_range = initializer_range
62
+ self.rms_norm_eps = rms_norm_eps
63
+ self.pretraining_tp = pretraining_tp
64
+ self.use_cache = use_cache
65
+ self.rope_theta = rope_theta
66
+ self.rope_scaling = rope_scaling
67
+ self._rope_scaling_validation()
68
+ self.attention_bias = attention_bias
69
+ self.attention_dropout = attention_dropout
70
+
71
+ self.num_experts = num_experts
72
+ self.num_selects = num_selects
73
+ self.size_experts = size_experts
74
+
75
+ self.gate_type = gate_type
76
+ self.gate_network = gate_network
77
+ self.gate_use_softmax = gate_use_softmax
78
+ self.gate_use_balance = gate_use_balance
79
+ self.gate_balance_loss_weight = gate_balance_loss_weight
80
+ self.gate_add_noise = gate_add_noise
81
+ self.gate_noise_epsilon = gate_noise_epsilon
82
+
83
+ self.calculator_type = calculator_type
84
+ self.multiply_gate_scores = multiply_gate_scores
85
+ self.score_scale_factor = score_scale_factor
86
+ self.add_weight_norm = add_weight_norm
87
+ self.drop_tokens = drop_tokens
88
+ self.dropped_padding = dropped_padding
89
+ self.capacity_factor = capacity_factor
90
+
91
+ # for backward compatibility
92
+ if num_key_value_heads is None:
93
+ num_key_value_heads = num_attention_heads
94
+
95
+ self.num_key_value_heads = num_key_value_heads
96
+
97
+ super().__init__(
98
+ pad_token_id=pad_token_id,
99
+ bos_token_id=bos_token_id,
100
+ eos_token_id=eos_token_id,
101
+ tie_word_embeddings=tie_word_embeddings,
102
+ **kwargs,
103
+ )
104
+
105
+ def _rope_scaling_validation(self):
106
+ """
107
+ Validate the `rope_scaling` configuration.
108
+ """
109
+ if self.rope_scaling is None:
110
+ return
111
+
112
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
113
+ raise ValueError(
114
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
115
+ f"got {self.rope_scaling}"
116
+ )
117
+ rope_scaling_type = self.rope_scaling.get("type", None)
118
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
119
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
120
+ raise ValueError(
121
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
122
+ )
123
+ if (
124
+ rope_scaling_factor is None
125
+ or not isinstance(rope_scaling_factor, float)
126
+ or rope_scaling_factor <= 1.0
127
+ ):
128
+ raise ValueError(
129
+ f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}"
130
+ )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.36.2"
7
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3b9d76d82f9f5d6f8d03da589fae79a95b12c03f43ef644f4b068c68913b70b
3
+ size 4998589368
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0e4e04523c8457023d7f10c4865dd3c0ab56a768b5f118ae82826c6c88e9e2f
3
+ size 4984439672
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f89bd5d2e817eaa4ff8eb6d4222189dbaaa1962173bb652f7d426a0e7b31c1d
3
+ size 3502447520
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_llama_moe_hf.py ADDED
@@ -0,0 +1,1690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.distributions.normal import Normal
11
+ from transformers.modeling_outputs import (
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.activations import ACT2FN
16
+ from transformers.utils import ModelOutput, logging
17
+ from transformers.cache_utils import Cache, DynamicCache
18
+ from transformers.modeling_attn_mask_utils import (
19
+ AttentionMaskConverter,
20
+ _prepare_4d_attention_mask,
21
+ _prepare_4d_causal_attention_mask,
22
+ _prepare_4d_causal_attention_mask_for_sdpa,
23
+ )
24
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25
+
26
+ from .configuration_llama_moe import LlamaMoEConfig
27
+
28
+
29
+ if is_flash_attn_2_available():
30
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
31
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
32
+
33
+
34
+ def _get_unpad_data(attention_mask):
35
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
36
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
37
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
38
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
39
+ return (
40
+ indices,
41
+ cu_seqlens,
42
+ max_seqlen_in_batch,
43
+ )
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CONFIG_FOR_DOC = "LlamaMoEConfig"
49
+
50
+
51
+ @dataclass
52
+ class CalculatorOutput(ModelOutput):
53
+ hidden_states: Optional[torch.FloatTensor] = None
54
+ num_dropped_tokens: Optional[int] = None
55
+
56
+
57
+ @dataclass
58
+ class BaseMoEModelOutputWithPast(ModelOutput):
59
+ """
60
+ Args:
61
+ num_dropped_tokens: layer idx to the number of dropped tokens
62
+ """
63
+
64
+ last_hidden_state: torch.FloatTensor = None
65
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
66
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
67
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
68
+ balance_loss: Optional[float] = None
69
+ num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None
70
+ gate_load: Optional[Tuple[list]] = None
71
+ gate_importance: Optional[Tuple[list]] = None
72
+
73
+
74
+ @dataclass
75
+ class MoECausalLMOutputWithPast(CausalLMOutputWithPast):
76
+ balance_loss: Optional[float] = None
77
+ num_dropped_tokens: Optional[Tuple[int]] = None
78
+ gate_load: Optional[Tuple[list[torch.Tensor]]] = None
79
+ gate_importance: Optional[Tuple[list[torch.Tensor]]] = None
80
+
81
+
82
+ @dataclass
83
+ class MoEMlpOutput(ModelOutput):
84
+ hidden_states: Optional[torch.FloatTensor] = None
85
+ balance_loss: Optional[torch.FloatTensor] = None
86
+ num_dropped_tokens: Optional[int] = None
87
+ gate_load: Optional[list] = None
88
+ gate_importance: Optional[list] = None
89
+
90
+
91
+ def _make_causal_mask(
92
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
93
+ ):
94
+ """
95
+ Make causal mask used for bi-directional self-attention.
96
+ """
97
+ bsz, tgt_len = input_ids_shape
98
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
99
+ mask_cond = torch.arange(mask.size(-1), device=device)
100
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
101
+ mask = mask.to(dtype)
102
+
103
+ if past_key_values_length > 0:
104
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
105
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
106
+
107
+
108
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
109
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
110
+ """
111
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
112
+ """
113
+ bsz, src_len = mask.size()
114
+ tgt_len = tgt_len if tgt_len is not None else src_len
115
+
116
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
117
+
118
+ inverted_mask = 1.0 - expanded_mask
119
+
120
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
121
+
122
+
123
+ class LlamaRMSNorm(nn.Module):
124
+ def __init__(self, hidden_size, eps=1e-6):
125
+ """
126
+ LlamaRMSNorm is equivalent to T5LayerNorm
127
+ """
128
+ super().__init__()
129
+ self.weight = nn.Parameter(torch.ones(hidden_size))
130
+ self.variance_epsilon = eps
131
+
132
+ def forward(self, hidden_states):
133
+ input_dtype = hidden_states.dtype
134
+ hidden_states = hidden_states.to(torch.float32)
135
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
136
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
137
+ return self.weight * hidden_states.to(input_dtype)
138
+
139
+
140
+ class LlamaRotaryEmbedding(torch.nn.Module):
141
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
142
+ super().__init__()
143
+
144
+ self.dim = dim
145
+ self.max_position_embeddings = max_position_embeddings
146
+ self.base = base
147
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
148
+ self.register_buffer("inv_freq", inv_freq)
149
+
150
+ # Build here to make `torch.jit.trace` work.
151
+ self._set_cos_sin_cache(
152
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
153
+ )
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
158
+
159
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
160
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
163
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
164
+
165
+ def forward(self, x, seq_len=None):
166
+ # x: [bs, num_attention_heads, seq_len, head_size]
167
+ if seq_len > self.max_seq_len_cached:
168
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
169
+
170
+ return (
171
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
172
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
173
+ )
174
+
175
+
176
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
177
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
178
+
179
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
180
+ self.scaling_factor = scaling_factor
181
+ super().__init__(dim, max_position_embeddings, base, device)
182
+
183
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
184
+ self.max_seq_len_cached = seq_len
185
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
186
+ t = t / self.scaling_factor
187
+
188
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
189
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
190
+ emb = torch.cat((freqs, freqs), dim=-1)
191
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
192
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
193
+
194
+
195
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
196
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
197
+
198
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
199
+ self.scaling_factor = scaling_factor
200
+ super().__init__(dim, max_position_embeddings, base, device)
201
+
202
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
203
+ self.max_seq_len_cached = seq_len
204
+
205
+ if seq_len > self.max_position_embeddings:
206
+ base = self.base * (
207
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
208
+ ) ** (self.dim / (self.dim - 2))
209
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
210
+ self.register_buffer("inv_freq", inv_freq)
211
+
212
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
213
+
214
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
215
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
216
+ emb = torch.cat((freqs, freqs), dim=-1)
217
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
218
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
219
+
220
+
221
+ def rotate_half(x):
222
+ """Rotates half the hidden dims of the input."""
223
+ x1 = x[..., : x.shape[-1] // 2]
224
+ x2 = x[..., x.shape[-1] // 2 :]
225
+ return torch.cat((-x2, x1), dim=-1)
226
+
227
+
228
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
229
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
230
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
231
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
232
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
233
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
234
+ q_embed = (q * cos) + (rotate_half(q) * sin)
235
+ k_embed = (k * cos) + (rotate_half(k) * sin)
236
+ return q_embed, k_embed
237
+
238
+
239
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
240
+ """
241
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
242
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
243
+ """
244
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
245
+ if n_rep == 1:
246
+ return hidden_states
247
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
248
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
249
+
250
+
251
+ class LlamaAttention(nn.Module):
252
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
253
+
254
+ def __init__(self, config: LlamaMoEConfig, layer_idx: Optional[int] = None):
255
+ super().__init__()
256
+ self.config = config
257
+ self.layer_idx = layer_idx
258
+ if layer_idx is None:
259
+ logger.warning_once(
260
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
261
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
262
+ "when creating this class."
263
+ )
264
+
265
+ self.attention_dropout = config.attention_dropout
266
+ self.hidden_size = config.hidden_size
267
+ self.num_heads = config.num_attention_heads
268
+ self.head_dim = self.hidden_size // self.num_heads
269
+ self.num_key_value_heads = config.num_key_value_heads
270
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
271
+ self.max_position_embeddings = config.max_position_embeddings
272
+ self.rope_theta = config.rope_theta
273
+ self.is_causal = True
274
+
275
+ if (self.head_dim * self.num_heads) != self.hidden_size:
276
+ raise ValueError(
277
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
278
+ f" and `num_heads`: {self.num_heads})."
279
+ )
280
+
281
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
282
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
283
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
284
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
285
+ self._init_rope()
286
+
287
+ def _init_rope(self):
288
+ if self.config.rope_scaling is None:
289
+ self.rotary_emb = LlamaRotaryEmbedding(
290
+ self.head_dim,
291
+ max_position_embeddings=self.max_position_embeddings,
292
+ base=self.rope_theta,
293
+ )
294
+ else:
295
+ scaling_type = self.config.rope_scaling["type"]
296
+ scaling_factor = self.config.rope_scaling["factor"]
297
+ if scaling_type == "linear":
298
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
299
+ self.head_dim,
300
+ max_position_embeddings=self.max_position_embeddings,
301
+ scaling_factor=scaling_factor,
302
+ base=self.rope_theta,
303
+ )
304
+ elif scaling_type == "dynamic":
305
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
306
+ self.head_dim,
307
+ max_position_embeddings=self.max_position_embeddings,
308
+ scaling_factor=scaling_factor,
309
+ base=self.rope_theta,
310
+ )
311
+ else:
312
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
313
+
314
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
315
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
316
+
317
+ def forward(
318
+ self,
319
+ hidden_states: torch.Tensor,
320
+ attention_mask: Optional[torch.Tensor] = None,
321
+ position_ids: Optional[torch.LongTensor] = None,
322
+ past_key_value: Optional[Cache] = None,
323
+ output_attentions: bool = False,
324
+ use_cache: bool = False,
325
+ **kwargs,
326
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
327
+ if "padding_mask" in kwargs:
328
+ warnings.warn(
329
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
330
+ )
331
+
332
+ bsz, q_len, _ = hidden_states.size()
333
+
334
+ if self.config.pretraining_tp > 1:
335
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
336
+ query_slices = self.q_proj.weight.split(
337
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
338
+ )
339
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
340
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
341
+
342
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
343
+ query_states = torch.cat(query_states, dim=-1)
344
+
345
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
346
+ key_states = torch.cat(key_states, dim=-1)
347
+
348
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
349
+ value_states = torch.cat(value_states, dim=-1)
350
+
351
+ else:
352
+ query_states = self.q_proj(hidden_states)
353
+ key_states = self.k_proj(hidden_states)
354
+ value_states = self.v_proj(hidden_states)
355
+
356
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
357
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
358
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
+
360
+ kv_seq_len = key_states.shape[-2]
361
+ if past_key_value is not None:
362
+ if self.layer_idx is None:
363
+ raise ValueError(
364
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
365
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
366
+ "with a layer index."
367
+ )
368
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
369
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
370
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
371
+
372
+ if past_key_value is not None:
373
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
374
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
375
+
376
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
377
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
378
+
379
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
380
+
381
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
382
+ raise ValueError(
383
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
384
+ f" {attn_weights.size()}"
385
+ )
386
+
387
+ if attention_mask is not None:
388
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
389
+ raise ValueError(
390
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
391
+ )
392
+ attn_weights = attn_weights + attention_mask
393
+
394
+ # upcast attention to fp32
395
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
396
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
397
+ attn_output = torch.matmul(attn_weights, value_states)
398
+
399
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
400
+ raise ValueError(
401
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
402
+ f" {attn_output.size()}"
403
+ )
404
+
405
+ attn_output = attn_output.transpose(1, 2).contiguous()
406
+
407
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
408
+
409
+ if self.config.pretraining_tp > 1:
410
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
411
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
412
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
413
+ else:
414
+ attn_output = self.o_proj(attn_output)
415
+
416
+ if not output_attentions:
417
+ attn_weights = None
418
+
419
+ return attn_output, attn_weights, past_key_value
420
+
421
+
422
+ class LlamaFlashAttention2(LlamaAttention):
423
+ """
424
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
425
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
426
+ flash attention and deal with padding tokens in case the input contains any of them.
427
+ """
428
+
429
+ def __init__(self, *args, **kwargs):
430
+ super().__init__(*args, **kwargs)
431
+
432
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
433
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
434
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
435
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
436
+
437
+ def forward(
438
+ self,
439
+ hidden_states: torch.Tensor,
440
+ attention_mask: Optional[torch.LongTensor] = None,
441
+ position_ids: Optional[torch.LongTensor] = None,
442
+ past_key_value: Optional[Cache] = None,
443
+ output_attentions: bool = False,
444
+ use_cache: bool = False,
445
+ **kwargs,
446
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
447
+ # LlamaFlashAttention2 attention does not support output_attentions
448
+ if "padding_mask" in kwargs:
449
+ warnings.warn(
450
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
451
+ )
452
+
453
+ # overwrite attention_mask with padding_mask
454
+ attention_mask = kwargs.pop("padding_mask")
455
+
456
+ output_attentions = False
457
+
458
+ bsz, q_len, _ = hidden_states.size()
459
+
460
+ query_states = self.q_proj(hidden_states)
461
+ key_states = self.k_proj(hidden_states)
462
+ value_states = self.v_proj(hidden_states)
463
+
464
+ # Flash attention requires the input to have the shape
465
+ # batch_size x seq_length x head_dim x hidden_dim
466
+ # therefore we just need to keep the original shape
467
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
468
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
469
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
470
+
471
+ kv_seq_len = key_states.shape[-2]
472
+ if past_key_value is not None:
473
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
474
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
475
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
476
+
477
+ if past_key_value is not None:
478
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
479
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
480
+
481
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
482
+ # to be able to avoid many of these transpose/reshape/view.
483
+ query_states = query_states.transpose(1, 2)
484
+ key_states = key_states.transpose(1, 2)
485
+ value_states = value_states.transpose(1, 2)
486
+
487
+ dropout_rate = self.attention_dropout if self.training else 0.0
488
+
489
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
490
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
491
+ # cast them back in the correct dtype just to be sure everything works as expected.
492
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
493
+ # in fp32. (LlamaRMSNorm handles it correctly)
494
+
495
+ input_dtype = query_states.dtype
496
+ if input_dtype == torch.float32:
497
+ if torch.is_autocast_enabled():
498
+ target_dtype = torch.get_autocast_gpu_dtype()
499
+ # Handle the case where the model is quantized
500
+ elif hasattr(self.config, "_pre_quantization_dtype"):
501
+ target_dtype = self.config._pre_quantization_dtype
502
+ else:
503
+ target_dtype = self.q_proj.weight.dtype
504
+
505
+ logger.warning_once(
506
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
507
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
508
+ f" {target_dtype}."
509
+ )
510
+
511
+ query_states = query_states.to(target_dtype)
512
+ key_states = key_states.to(target_dtype)
513
+ value_states = value_states.to(target_dtype)
514
+
515
+ attn_output = self._flash_attention_forward(
516
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
517
+ )
518
+
519
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
520
+ attn_output = self.o_proj(attn_output)
521
+
522
+ if not output_attentions:
523
+ attn_weights = None
524
+
525
+ return attn_output, attn_weights, past_key_value
526
+
527
+ def _flash_attention_forward(
528
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
529
+ ):
530
+ """
531
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
532
+ first unpad the input, then computes the attention scores and pad the final attention scores.
533
+
534
+ Args:
535
+ query_states (`torch.Tensor`):
536
+ Input query states to be passed to Flash Attention API
537
+ key_states (`torch.Tensor`):
538
+ Input key states to be passed to Flash Attention API
539
+ value_states (`torch.Tensor`):
540
+ Input value states to be passed to Flash Attention API
541
+ attention_mask (`torch.Tensor`):
542
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
543
+ position of padding tokens and 1 for the position of non-padding tokens.
544
+ dropout (`int`, *optional*):
545
+ Attention dropout
546
+ softmax_scale (`float`, *optional*):
547
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
548
+ """
549
+ if not self._flash_attn_uses_top_left_mask:
550
+ causal = self.is_causal
551
+ else:
552
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
553
+ causal = self.is_causal and query_length != 1
554
+
555
+ # Contains at least one padding token in the sequence
556
+ if attention_mask is not None:
557
+ batch_size = query_states.shape[0]
558
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
559
+ query_states, key_states, value_states, attention_mask, query_length
560
+ )
561
+
562
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
563
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
564
+
565
+ attn_output_unpad = flash_attn_varlen_func(
566
+ query_states,
567
+ key_states,
568
+ value_states,
569
+ cu_seqlens_q=cu_seqlens_q,
570
+ cu_seqlens_k=cu_seqlens_k,
571
+ max_seqlen_q=max_seqlen_in_batch_q,
572
+ max_seqlen_k=max_seqlen_in_batch_k,
573
+ dropout_p=dropout,
574
+ softmax_scale=softmax_scale,
575
+ causal=causal,
576
+ )
577
+
578
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
579
+ else:
580
+ attn_output = flash_attn_func(
581
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
582
+ )
583
+
584
+ return attn_output
585
+
586
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
587
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
588
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
589
+
590
+ key_layer = index_first_axis(
591
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
592
+ )
593
+ value_layer = index_first_axis(
594
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
595
+ )
596
+ if query_length == kv_seq_len:
597
+ query_layer = index_first_axis(
598
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
599
+ )
600
+ cu_seqlens_q = cu_seqlens_k
601
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
602
+ indices_q = indices_k
603
+ elif query_length == 1:
604
+ max_seqlen_in_batch_q = 1
605
+ cu_seqlens_q = torch.arange(
606
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
607
+ ) # There is a memcpy here, that is very bad.
608
+ indices_q = cu_seqlens_q[:-1]
609
+ query_layer = query_layer.squeeze(1)
610
+ else:
611
+ # The -q_len: slice assumes left padding.
612
+ attention_mask = attention_mask[:, -query_length:]
613
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
614
+
615
+ return (
616
+ query_layer,
617
+ key_layer,
618
+ value_layer,
619
+ indices_q,
620
+ (cu_seqlens_q, cu_seqlens_k),
621
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
622
+ )
623
+
624
+
625
+ class LlamaSdpaAttention(LlamaAttention):
626
+ """
627
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
628
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
629
+ SDPA API.
630
+ """
631
+
632
+ # Adapted from LlamaAttention.forward
633
+ def forward(
634
+ self,
635
+ hidden_states: torch.Tensor,
636
+ attention_mask: Optional[torch.Tensor] = None,
637
+ position_ids: Optional[torch.LongTensor] = None,
638
+ past_key_value: Optional[Cache] = None,
639
+ output_attentions: bool = False,
640
+ use_cache: bool = False,
641
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
642
+ if output_attentions:
643
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
644
+ logger.warning_once(
645
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
646
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
647
+ )
648
+ return super().forward(
649
+ hidden_states=hidden_states,
650
+ attention_mask=attention_mask,
651
+ position_ids=position_ids,
652
+ past_key_value=past_key_value,
653
+ output_attentions=output_attentions,
654
+ use_cache=use_cache,
655
+ )
656
+
657
+ bsz, q_len, _ = hidden_states.size()
658
+
659
+ query_states = self.q_proj(hidden_states)
660
+ key_states = self.k_proj(hidden_states)
661
+ value_states = self.v_proj(hidden_states)
662
+
663
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
664
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
665
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
666
+
667
+ kv_seq_len = key_states.shape[-2]
668
+ if past_key_value is not None:
669
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
670
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
671
+
672
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
673
+
674
+ if past_key_value is not None:
675
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
676
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
677
+
678
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
679
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
680
+
681
+ if attention_mask is not None:
682
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
683
+ raise ValueError(
684
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
685
+ )
686
+
687
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
688
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
689
+ if query_states.device.type == "cuda" and attention_mask is not None:
690
+ query_states = query_states.contiguous()
691
+ key_states = key_states.contiguous()
692
+ value_states = value_states.contiguous()
693
+
694
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
695
+ query_states,
696
+ key_states,
697
+ value_states,
698
+ attn_mask=attention_mask,
699
+ dropout_p=self.attention_dropout if self.training else 0.0,
700
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
701
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
702
+ )
703
+
704
+ attn_output = attn_output.transpose(1, 2).contiguous()
705
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
706
+
707
+ attn_output = self.o_proj(attn_output)
708
+
709
+ return attn_output, None, past_key_value
710
+
711
+
712
+ LLAMA_ATTENTION_CLASSES = {
713
+ "eager": LlamaAttention,
714
+ "flash_attention_2": LlamaFlashAttention2,
715
+ "sdpa": LlamaSdpaAttention,
716
+ }
717
+
718
+
719
+ class TopKBalancedNoisyGate(nn.Module):
720
+ def __init__(
721
+ self,
722
+ input_size,
723
+ num_experts,
724
+ num_selects,
725
+ gate_network="mlp",
726
+ use_softmax=True,
727
+ use_balance=True,
728
+ balance_loss_weight=1e-2,
729
+ add_noise=True,
730
+ noise_epsilon=1e-2,
731
+ ):
732
+ super(TopKBalancedNoisyGate, self).__init__()
733
+ assert num_selects <= num_experts
734
+ self.input_size = input_size
735
+ self.num_experts = num_experts
736
+ self.num_selects = num_selects
737
+
738
+ self.gate_network_type = gate_network
739
+ self.gate_network = self.get_gate_network(gate_network, input_size, num_experts)
740
+
741
+ self.use_softmax = use_softmax
742
+ self.softmax = nn.Softmax(1)
743
+
744
+ self.use_balance = use_balance
745
+ self.balance_loss_weight = balance_loss_weight
746
+
747
+ # add_noise
748
+ self.add_noise = add_noise
749
+ self.noise_epsilon = noise_epsilon
750
+ self.warned = False
751
+ if self.add_noise:
752
+ self.weight_noise = nn.Linear(input_size, num_experts, bias=False)
753
+ self.weight_noise.weight.data = torch.zeros(
754
+ (num_experts, input_size),
755
+ requires_grad=True,
756
+ device=self.weight_noise.weight.data.device,
757
+ dtype=self.weight_noise.weight.data.dtype,
758
+ )
759
+ self.mean = 0.0
760
+ self.std = 1.0
761
+ self.normal = Normal(self.mean, self.std)
762
+ self.softplus = nn.Softplus()
763
+
764
+ self.reset_parameters()
765
+
766
+ def get_gate_network(self, gate_type, input_size, num_experts):
767
+ gate_type = gate_type.lower()
768
+
769
+ if gate_type == "linear":
770
+ gate_network = nn.Linear(input_size, num_experts, bias=False)
771
+ nn.init.zeros_(gate_network.weight)
772
+ elif gate_type == "mlp":
773
+ gate_network = torch.nn.Sequential(
774
+ torch.nn.Linear(input_size, num_experts, bias=False),
775
+ torch.nn.Tanh(),
776
+ torch.nn.Linear(num_experts, num_experts, bias=False),
777
+ )
778
+ else:
779
+ raise ValueError(f'Unexpected gate_type: {gate_type}.')
780
+
781
+ return gate_network
782
+
783
+ def reset_gate_network(self):
784
+ if "gate_network_type" not in vars(self):
785
+ raise KeyError(f"{type(self)} does not have a gate network.")
786
+ else:
787
+ self.gate_network = self.get_gate_network(
788
+ self.gate_network_type, self.input_size, self.num_experts
789
+ )
790
+
791
+ def reset_parameters(self):
792
+ if self.add_noise:
793
+ nn.init.zeros_(self.weight_noise.weight)
794
+ # nn.init.zeros_(self.weight_noise)
795
+
796
+ def cv_squared(self, x, eps=1e-10):
797
+ """The squared coefficient of variation of a sample.
798
+ Useful as a loss to encourage a positive distribution to be more uniform.
799
+ Epsilons added for numerical stability.
800
+ Returns 0 for an empty Tensor.
801
+ Args:
802
+ x: a `Tensor`.
803
+ Returns:
804
+ a `Scalar`.s
805
+ """
806
+ if x.shape[0] == 1:
807
+ return torch.tensor(0.0, device=x.device)
808
+ return x.float().var() / (x.float().mean() ** 2 + eps)
809
+
810
+ def forward(self, x):
811
+ logits_gate = self.gate_network(x)
812
+ if self.training and self.add_noise:
813
+ noise_mm = self.weight_noise(x)
814
+ noise_control = self.softplus(noise_mm) + self.noise_epsilon
815
+ logits_noise = torch.randn_like(logits_gate) * noise_control
816
+ logits = logits_gate + logits_noise
817
+ else:
818
+ logits = logits_gate
819
+
820
+ top_logits, top_indices = logits.topk(min(self.num_selects + 1, self.num_experts), dim=1) # 选择并排序前k+1个权重
821
+ top_k_logits = top_logits[:, :self.num_selects]
822
+ top_k_indices = top_indices[:, :self.num_selects]
823
+ top_k_scores = self.softmax(top_k_logits.to(torch.float32)) if self.use_softmax else top_k_logits
824
+ top_k_scores = top_k_scores.to(logits.dtype)
825
+
826
+ zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device)
827
+ scores_filtered = zeros.scatter(dim=1, index=top_k_indices, src=top_k_scores) # shape(batch_size, num_experts)
828
+ importance = scores_filtered.sum(0) # shape(num_experts)
829
+
830
+ if self.training:
831
+ if self.add_noise and self.num_selects != self.num_experts:
832
+ batch_size = top_logits.size(0)
833
+ m = top_logits.size(1)
834
+ top_values_flat = top_logits.flatten()
835
+ threshold_positions_if_in = torch.arange(batch_size, device=x.device) * m + self.num_selects
836
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
837
+ is_in = torch.gt(logits_noise, threshold_if_in)
838
+ threshold_positions_if_out = threshold_positions_if_in - 1
839
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
840
+ # is each value currently in the top k.
841
+ prob_if_in = self.normal.cdf((logits_gate - threshold_if_in) / noise_control)
842
+ prob_if_out = self.normal.cdf((logits_gate - threshold_if_out) / noise_control)
843
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
844
+ load = prob.sum(0)
845
+ else:
846
+ load = (scores_filtered > 0).sum(0)
847
+ if not self.add_noise and not self.warned:
848
+ warnings.warn('Gradient-trackable implementation for load calculation is only available when "add_noise=True". '
849
+ 'Training without noise will block the gradient from "load" path and lead to inconsistency in optimization objectives.')
850
+ self.warned = True
851
+ else:
852
+ load = (scores_filtered > 0).sum(0)
853
+
854
+ if self.use_balance:
855
+ balance_loss = self.cv_squared(importance) + self.cv_squared(load)
856
+ balance_loss *= self.balance_loss_weight
857
+ else:
858
+ balance_loss = torch.tensor(-100.0, device=x.device)
859
+
860
+ return {
861
+ "topK_indices": top_k_indices,
862
+ "topK_scores": top_k_scores,
863
+ "balance_loss": balance_loss,
864
+ "load": load,
865
+ "importance": importance,
866
+ }
867
+
868
+
869
+ class LinearGLUExperts(nn.Module):
870
+ """
871
+ Modified from transformers.models.llama.modeling_llama.LlamaMLP
872
+ """
873
+
874
+ __constants__ = [
875
+ "bias",
876
+ "in_features",
877
+ "hidden_features",
878
+ "out_features",
879
+ "hidden_act",
880
+ "num_experts",
881
+ "size_experts",
882
+ ]
883
+
884
+ def __init__(
885
+ self,
886
+ in_features,
887
+ hidden_features,
888
+ out_features,
889
+ hidden_act,
890
+ num_experts,
891
+ size_experts=None,
892
+ bias=True,
893
+ device=None,
894
+ dtype=None,
895
+ ):
896
+ factory_kwargs = {"device": device, "dtype": dtype}
897
+ super(LinearGLUExperts, self).__init__()
898
+ self.in_features = in_features
899
+ self.hidden_features = hidden_features
900
+ self.out_features = out_features
901
+ self.hidden_act = hidden_act
902
+ self.num_experts = num_experts
903
+
904
+ if size_experts is None:
905
+ # all experts share the same number of hidden neurons
906
+ assert hidden_features % num_experts == 0
907
+ size_per_expert = hidden_features // num_experts
908
+ size_experts = [size_per_expert for _ in range(num_experts)]
909
+ else:
910
+ # use specified expert sizes
911
+ assert (
912
+ len(size_experts) == num_experts
913
+ and sum(size_experts) == hidden_features
914
+ )
915
+ self.size_experts = size_experts
916
+
917
+ self.act_fn = ACT2FN[hidden_act]
918
+
919
+ self.weight_gate = nn.ParameterList()
920
+ self.weight_up = nn.ParameterList()
921
+ self.weight_down = nn.ParameterList()
922
+
923
+ for i in range(num_experts):
924
+ # this matrix will be transposed when performing linear forwarding
925
+ this_expert_weight_gate = nn.Parameter(
926
+ torch.empty((size_experts[i], in_features), **factory_kwargs)
927
+ )
928
+ # this matrix will be transposed when performing linear forwarding
929
+ this_expert_weight_up = nn.Parameter(
930
+ torch.empty((size_experts[i], in_features), **factory_kwargs)
931
+ )
932
+ # this matrix will be transposed when performing linear forwarding
933
+ this_expert_weight_down = nn.Parameter(
934
+ torch.empty((out_features, size_experts[i]), **factory_kwargs)
935
+ )
936
+ self.weight_gate.append(this_expert_weight_gate)
937
+ self.weight_up.append(this_expert_weight_up)
938
+ self.weight_down.append(this_expert_weight_down)
939
+
940
+ if bias:
941
+ self.bias_gate = nn.ParameterList()
942
+ self.bias_up = nn.ParameterList()
943
+ self.bias_down = nn.ParameterList()
944
+
945
+ for i in range(num_experts):
946
+ this_expert_bias_gate = nn.Parameter(
947
+ torch.empty((size_experts[i],), **factory_kwargs)
948
+ )
949
+ this_expert_bias_up = nn.Parameter(
950
+ torch.empty((size_experts[i],), **factory_kwargs)
951
+ )
952
+ this_expert_bias_down = nn.Parameter(
953
+ torch.empty((out_features,), **factory_kwargs)
954
+ )
955
+ self.bias_gate.append(this_expert_bias_gate)
956
+ self.bias_up.append(this_expert_bias_up)
957
+ self.bias_down.append(this_expert_bias_down)
958
+ else:
959
+ self.register_parameter("bias_gate", None)
960
+ self.register_parameter("bias_up", None)
961
+ self.register_parameter("bias_down", None)
962
+
963
+ self.reset_parameters()
964
+
965
+ def reset_parameters(self):
966
+ for i in range(self.num_experts):
967
+ nn.init.kaiming_uniform_(self.weight_gate[i], a=math.sqrt(5))
968
+ nn.init.kaiming_uniform_(self.weight_up[i], a=math.sqrt(5))
969
+ nn.init.kaiming_uniform_(self.weight_down[i], a=math.sqrt(5))
970
+ if self.bias_gate is not None:
971
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_gate[i])
972
+ bound = 1 / math.sqrt(fan_in)
973
+ nn.init.uniform_(self.bias_gate[i], -bound, bound)
974
+ if self.bias_up is not None:
975
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_up[i])
976
+ bound = 1 / math.sqrt(fan_in)
977
+ nn.init.uniform_(self.bias_up[i], -bound, bound)
978
+ if self.bias_down is not None:
979
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_down[i])
980
+ bound = 1 / math.sqrt(fan_in)
981
+ nn.init.uniform_(self.bias_down[i], -bound, bound)
982
+
983
+ def forward(self, input, i):
984
+ gate = self.act_fn(
985
+ F.linear(
986
+ input,
987
+ self.weight_gate[i],
988
+ self.bias_gate[i] if self.bias_gate is not None else None,
989
+ )
990
+ )
991
+ up = F.linear(
992
+ input,
993
+ self.weight_up[i],
994
+ self.bias_up[i] if self.bias_up is not None else None,
995
+ )
996
+ down = F.linear(
997
+ gate * up,
998
+ self.weight_down[i],
999
+ self.bias_down[i] if self.bias_down is not None else None,
1000
+ )
1001
+ return down
1002
+
1003
+ def extra_repr(self):
1004
+ return (
1005
+ "in_features={}, hidden_features={}, out_features={}, hidden_act={},"
1006
+ " num_experts={}, size_experts={}, bias={}".format(
1007
+ self.in_features,
1008
+ self.hidden_features,
1009
+ self.out_features,
1010
+ self.hidden_act,
1011
+ self.num_experts,
1012
+ self.size_experts,
1013
+ self.bias_gate is not None,
1014
+ )
1015
+ )
1016
+
1017
+
1018
+ class UniversalCalculator(nn.Module):
1019
+ def __init__(
1020
+ self,
1021
+ experts: LinearGLUExperts,
1022
+ multiply_gate_scores=True,
1023
+ score_scale_factor=1.0,
1024
+ add_weight_norm: bool = False,
1025
+ ):
1026
+ super(UniversalCalculator, self).__init__()
1027
+ self.experts = experts
1028
+ # TODO (zhutong): use vmap to boost the training efficiency
1029
+ # self.experts_vmap = torch.vmap(self.experts)
1030
+ self.multiply_gate_scores = multiply_gate_scores
1031
+ self.score_scale_factor = score_scale_factor
1032
+ self.num_experts = experts.num_experts
1033
+ self.mlp_norm = None
1034
+ if multiply_gate_scores and add_weight_norm:
1035
+ raise NotImplementedError
1036
+
1037
+ def reset_experts(self):
1038
+ self.experts.reset_parameters()
1039
+
1040
+ def forward(
1041
+ self, x, topK_indices, topK_scores, expert_batch_size=None, **kwargs
1042
+ ) -> CalculatorOutput:
1043
+ batch_size = topK_indices.size(0) # topK_indices: (bsz*seq_len, num_selects)
1044
+ num_selects = topK_indices.size(1)
1045
+ topK_indices = topK_indices.flatten() # shape(batch_size*num_selects)
1046
+ topK_scores = topK_scores.flatten() # shape(batch_size*num_selects)
1047
+ batch_indices = torch.arange(
1048
+ batch_size, device=topK_scores.device
1049
+ ).repeat_interleave(num_selects)
1050
+
1051
+ _, index_sorted_topK_indices = topK_indices.sort(0)
1052
+
1053
+ sorted_topK_scores = topK_scores.index_select(0, index_sorted_topK_indices)
1054
+ sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices)
1055
+
1056
+ if expert_batch_size is None:
1057
+ expert_batch_size = topK_indices.bincount(
1058
+ minlength=self.num_experts
1059
+ ).tolist()
1060
+
1061
+ sorted_x = x.index_select(0, sorted_batch_indices)
1062
+ split_x = torch.split(sorted_x, expert_batch_size, dim=0)
1063
+
1064
+ expert_outputs = [
1065
+ self.experts(split_x[i], i)
1066
+ for i in range(self.num_experts)
1067
+ if split_x[i].shape[0] > 0
1068
+ ]
1069
+
1070
+ # (bsz*seq_len*num_selects, hidden_size)
1071
+ cat_expert_outputs = torch.cat(expert_outputs, 0)
1072
+ output_dim = cat_expert_outputs.size(1)
1073
+ if self.multiply_gate_scores:
1074
+ if self.mlp_norm is None:
1075
+ cat_expert_outputs = torch.mul(
1076
+ cat_expert_outputs,
1077
+ sorted_topK_scores.reshape(-1, 1) * self.score_scale_factor,
1078
+ )
1079
+ # cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) * 1.0)
1080
+ else:
1081
+ cat_expert_outputs = torch.mul(
1082
+ cat_expert_outputs, sorted_topK_scores.reshape(-1, 1)
1083
+ )
1084
+ cat_expert_outputs = self.mlp_norm(cat_expert_outputs)
1085
+
1086
+ zeros = torch.zeros(
1087
+ (batch_size, output_dim),
1088
+ device=cat_expert_outputs.device,
1089
+ dtype=cat_expert_outputs.dtype,
1090
+ )
1091
+ y = zeros.index_add(0, sorted_batch_indices, cat_expert_outputs)
1092
+
1093
+ return CalculatorOutput(hidden_states=y, num_dropped_tokens=torch.tensor(-1.0))
1094
+
1095
+
1096
+ class BaseMoELayer(nn.Module):
1097
+ def __init__(self):
1098
+ super(BaseMoELayer, self).__init__()
1099
+
1100
+ self.gate: TopKBalancedNoisyGate
1101
+ self.calculator: UniversalCalculator
1102
+
1103
+ def _create_gate(self, **kwargs):
1104
+ self.gate_type = kwargs.get("gate_type", "TopKBalancedNoisyGate")
1105
+
1106
+ if self.gate_type == "TopKBalancedNoisyGate": # noisy gate
1107
+ self.gate = TopKBalancedNoisyGate(
1108
+ self.input_size,
1109
+ self.num_experts,
1110
+ self.num_selects,
1111
+ gate_network=kwargs.get("gate_network", "mlp"),
1112
+ use_softmax=kwargs.get("gate_use_softmax", True),
1113
+ use_balance=kwargs.get("gate_use_balance", True),
1114
+ balance_loss_weight=kwargs.get("gate_balance_loss_weight", 1e-2),
1115
+ add_noise=kwargs.get("gate_add_noise", True),
1116
+ noise_epsilon=kwargs.get("gate_noise_epsilon", 1e-2),
1117
+ )
1118
+ else:
1119
+ raise NotImplementedError
1120
+
1121
+ def _create_calculator(self, experts, **kwargs):
1122
+ self.calculator_type = kwargs.get("calculator_type", "UniversalCalculator")
1123
+
1124
+ if self.calculator_type == "UniversalCalculator": # top K calculator
1125
+ self.calculator = UniversalCalculator(
1126
+ experts,
1127
+ multiply_gate_scores=kwargs.get("multiply_gate_scores", True),
1128
+ score_scale_factor=kwargs.get("score_scale_factor", 1.0),
1129
+ add_weight_norm=kwargs.get("add_weight_norm", False),
1130
+ )
1131
+ else:
1132
+ raise NotImplementedError
1133
+
1134
+ def forward(self, x, attention_mask=None) -> MoEMlpOutput:
1135
+ original_shape = x.shape[:-1]
1136
+ x = x.reshape(-1, self.input_size)
1137
+ flattened_mask = None
1138
+ if attention_mask is not None and len(attention_mask.shape) == 2:
1139
+ flattened_mask = attention_mask.flatten()
1140
+ flattened_shape = flattened_mask.shape
1141
+ x = x[flattened_mask.bool()]
1142
+
1143
+ gate_outputs: dict = self.gate(x)
1144
+ calc_outs: CalculatorOutput = self.calculator(x, **gate_outputs)
1145
+
1146
+ y = calc_outs.hidden_states
1147
+ if flattened_mask is not None:
1148
+ y = torch.zeros(flattened_shape + (self.output_size,), dtype=x.dtype, device=x.device) # (batch_size*seq_len, output_size)
1149
+ y[flattened_mask.bool()] = calc_outs.hidden_states # (non_padding_num, output_size)
1150
+ y = y.reshape(original_shape + (self.output_size,))
1151
+
1152
+ return MoEMlpOutput(
1153
+ hidden_states=y,
1154
+ balance_loss=gate_outputs.get("balance_loss"),
1155
+ num_dropped_tokens=calc_outs.num_dropped_tokens,
1156
+ gate_load=gate_outputs.get("load", torch.tensor(-1)),
1157
+ gate_importance=gate_outputs.get("importance", torch.tensor(-1)),
1158
+ )
1159
+
1160
+ def reset_gate_network(self):
1161
+ self.gate.reset_gate_network()
1162
+
1163
+ def reset_experts(self):
1164
+ self.calculator.reset_experts()
1165
+
1166
+
1167
+ class LinearGLUMoELayer(BaseMoELayer):
1168
+ def __init__(
1169
+ self,
1170
+ input_size,
1171
+ hidden_size,
1172
+ output_size,
1173
+ hidden_act,
1174
+ num_experts,
1175
+ num_selects,
1176
+ size_experts=None,
1177
+ bias=True,
1178
+ **kwargs,
1179
+ ):
1180
+ super(LinearGLUMoELayer, self).__init__()
1181
+ assert num_selects <= num_experts
1182
+ self.input_size = input_size
1183
+ self.hidden_size = hidden_size
1184
+ self.output_size = output_size
1185
+ self.hidden_act = hidden_act
1186
+ self.num_experts = num_experts
1187
+ self.num_selects = num_selects
1188
+ self.size_experts = size_experts
1189
+ self.bias = bias
1190
+
1191
+ experts = LinearGLUExperts(
1192
+ input_size,
1193
+ hidden_size,
1194
+ output_size,
1195
+ hidden_act,
1196
+ num_experts,
1197
+ size_experts=size_experts,
1198
+ bias=bias,
1199
+ )
1200
+
1201
+ self._create_gate(**kwargs)
1202
+ self._create_calculator(experts, **kwargs)
1203
+
1204
+
1205
+ class LlamaMoEDecoderLayer(nn.Module):
1206
+ def __init__(self, config: LlamaMoEConfig, layer_index):
1207
+ super().__init__()
1208
+
1209
+ self.hidden_size = config.hidden_size
1210
+ # self.self_attn = LlamaAttention(config=config)
1211
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_index)
1212
+
1213
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1214
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1215
+
1216
+ gating_config = {
1217
+ # all gates
1218
+ "gate_type": config.gate_type,
1219
+ "gate_network": config.gate_network,
1220
+ "gate_use_softmax": config.gate_use_softmax,
1221
+ "gate_use_balance": config.gate_use_balance,
1222
+ "gate_balance_loss_weight": config.gate_balance_loss_weight,
1223
+ "gate_add_noise": config.gate_add_noise,
1224
+ # TopKBalancedNoisyGate
1225
+ "gate_noise_epsilon": config.gate_noise_epsilon,
1226
+ }
1227
+ calculator_config = {
1228
+ # all calculators
1229
+ "calculator_type": config.calculator_type,
1230
+ "multiply_gate_scores": config.multiply_gate_scores,
1231
+ "score_scale_factor": (
1232
+ config.score_scale_factor[layer_index]
1233
+ if isinstance(config.score_scale_factor, list)
1234
+ else config.score_scale_factor
1235
+ ),
1236
+ "add_weight_norm": config.add_weight_norm,
1237
+ # SwitchDropTokenCalculator
1238
+ "drop_tokens": config.drop_tokens,
1239
+ "dropped_padding": config.dropped_padding,
1240
+ "capacity_factor": config.capacity_factor,
1241
+ }
1242
+
1243
+ self.mlp = LinearGLUMoELayer(
1244
+ input_size=self.hidden_size,
1245
+ hidden_size=config.intermediate_size,
1246
+ output_size=self.hidden_size,
1247
+ hidden_act=config.hidden_act,
1248
+ num_experts=config.num_experts,
1249
+ num_selects=config.num_selects,
1250
+ size_experts=(
1251
+ config.size_experts[layer_index]
1252
+ if config.size_experts is not None
1253
+ else None
1254
+ ),
1255
+ bias=False,
1256
+ **gating_config,
1257
+ **calculator_config,
1258
+ )
1259
+
1260
+ def forward(
1261
+ self,
1262
+ hidden_states,
1263
+ attention_mask=None,
1264
+ position_ids=None,
1265
+ past_key_value=None,
1266
+ output_attentions=False,
1267
+ use_cache=False,
1268
+ ) -> tuple:
1269
+ residual = hidden_states
1270
+ hidden_states = self.input_layernorm(hidden_states)
1271
+
1272
+ # Self Attention
1273
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1274
+ hidden_states=hidden_states,
1275
+ attention_mask=attention_mask,
1276
+ position_ids=position_ids,
1277
+ past_key_value=past_key_value,
1278
+ output_attentions=output_attentions,
1279
+ use_cache=use_cache,
1280
+ )
1281
+ hidden_states = residual + hidden_states
1282
+
1283
+ # Fully Connected
1284
+ residual = hidden_states
1285
+ hidden_states = self.post_attention_layernorm(hidden_states)
1286
+ mlp_outs: MoEMlpOutput = self.mlp(hidden_states, attention_mask=attention_mask)
1287
+ hidden_states = residual + mlp_outs.hidden_states
1288
+
1289
+ outputs = (
1290
+ hidden_states,
1291
+ mlp_outs.balance_loss,
1292
+ mlp_outs.num_dropped_tokens,
1293
+ mlp_outs.gate_load,
1294
+ mlp_outs.gate_importance,
1295
+ )
1296
+ if output_attentions:
1297
+ outputs += (self_attn_weights,)
1298
+ if use_cache:
1299
+ outputs += (present_key_value,)
1300
+
1301
+ return outputs
1302
+
1303
+
1304
+ class LlamaMoEPreTrainedModel(PreTrainedModel):
1305
+ config_class = LlamaMoEConfig
1306
+ base_model_prefix = "model"
1307
+ supports_gradient_checkpointing = True
1308
+ _no_split_modules = ["LlamaMoEDecoderLayer"]
1309
+ _skip_keys_device_placement = "past_key_values"
1310
+ _supports_flash_attn_2 = True
1311
+
1312
+ def _init_weights(self, module):
1313
+ std = self.config.initializer_range
1314
+ if isinstance(module, nn.Linear):
1315
+ module.weight.data.normal_(mean=0.0, std=std)
1316
+ if module.bias is not None:
1317
+ module.bias.data.zero_()
1318
+ elif isinstance(module, nn.Embedding):
1319
+ module.weight.data.normal_(mean=0.0, std=std)
1320
+ if module.padding_idx is not None:
1321
+ module.weight.data[module.padding_idx].zero_()
1322
+
1323
+
1324
+ class LlamaMoEModel(LlamaMoEPreTrainedModel):
1325
+ def __init__(self, config: LlamaMoEConfig):
1326
+ super().__init__(config)
1327
+ self.padding_idx = config.pad_token_id
1328
+ self.vocab_size = config.vocab_size
1329
+
1330
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1331
+ self.layers = nn.ModuleList(
1332
+ [LlamaMoEDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
1333
+ )
1334
+ self._use_sdpa = config._attn_implementation == "sdpa"
1335
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1336
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1337
+ self.gradient_checkpointing = False
1338
+ self.post_init()
1339
+
1340
+ def get_input_embeddings(self):
1341
+ return self.embed_tokens
1342
+
1343
+ def set_input_embeddings(self, value):
1344
+ self.embed_tokens = value
1345
+
1346
+ def forward(
1347
+ self,
1348
+ input_ids=None,
1349
+ attention_mask=None,
1350
+ position_ids=None,
1351
+ past_key_values=None,
1352
+ inputs_embeds=None,
1353
+ use_cache=None,
1354
+ output_attentions=None,
1355
+ output_hidden_states=None,
1356
+ return_dict=None,
1357
+ ):
1358
+ output_attentions = (
1359
+ output_attentions
1360
+ if output_attentions is not None
1361
+ else self.config.output_attentions
1362
+ )
1363
+ output_hidden_states = (
1364
+ output_hidden_states
1365
+ if output_hidden_states is not None
1366
+ else self.config.output_hidden_states
1367
+ )
1368
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1369
+
1370
+ return_dict = (
1371
+ return_dict if return_dict is not None else self.config.use_return_dict
1372
+ )
1373
+
1374
+ # retrieve input_ids and inputs_embeds
1375
+ if input_ids is not None and inputs_embeds is not None:
1376
+ raise ValueError(
1377
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at"
1378
+ " the same time"
1379
+ )
1380
+ elif input_ids is not None:
1381
+ batch_size, seq_length = input_ids.shape
1382
+ elif inputs_embeds is not None:
1383
+ batch_size, seq_length, _ = inputs_embeds.shape
1384
+ else:
1385
+ raise ValueError(
1386
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1387
+ )
1388
+
1389
+ if self.gradient_checkpointing and self.training:
1390
+ if use_cache:
1391
+ logger.warning_once(
1392
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1393
+ )
1394
+ use_cache = False
1395
+
1396
+ past_key_values_length = 0
1397
+ if use_cache:
1398
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1399
+ if use_legacy_cache:
1400
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1401
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1402
+
1403
+ if position_ids is None:
1404
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1405
+ position_ids = torch.arange(
1406
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1407
+ )
1408
+ position_ids = position_ids.unsqueeze(0)
1409
+
1410
+ if inputs_embeds is None:
1411
+ inputs_embeds = self.embed_tokens(input_ids)
1412
+
1413
+ if self._use_flash_attention_2:
1414
+ # 2d mask is passed through the layers
1415
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1416
+ elif self._use_sdpa and not output_attentions:
1417
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1418
+ # the manual implementation that requires a 4D causal mask in all cases.
1419
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1420
+ attention_mask,
1421
+ (batch_size, seq_length),
1422
+ inputs_embeds,
1423
+ past_key_values_length,
1424
+ )
1425
+ else:
1426
+ # 4d mask is passed through the layers
1427
+ attention_mask = _prepare_4d_causal_attention_mask(
1428
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1429
+ )
1430
+
1431
+ hidden_states = inputs_embeds
1432
+ balance_loss = 0.0
1433
+
1434
+ # decoder layers
1435
+ all_hidden_states = () if output_hidden_states else None
1436
+ all_self_attns = () if output_attentions else None
1437
+ next_decoder_cache = None
1438
+
1439
+ num_dropped_tokens = ()
1440
+ gate_load = ()
1441
+ gate_importance = ()
1442
+ for idx, decoder_layer in enumerate(self.layers):
1443
+ if output_hidden_states:
1444
+ all_hidden_states += (hidden_states,)
1445
+
1446
+ if self.gradient_checkpointing and self.training:
1447
+ layer_outputs = self._gradient_checkpointing_func(
1448
+ decoder_layer.__call__,
1449
+ hidden_states,
1450
+ attention_mask,
1451
+ position_ids,
1452
+ past_key_values,
1453
+ output_attentions,
1454
+ use_cache,
1455
+ )
1456
+ else:
1457
+ layer_outputs = decoder_layer(
1458
+ hidden_states,
1459
+ attention_mask=attention_mask,
1460
+ position_ids=position_ids,
1461
+ past_key_value=past_key_values,
1462
+ output_attentions=output_attentions,
1463
+ use_cache=use_cache,
1464
+ )
1465
+
1466
+ hidden_states = layer_outputs[0]
1467
+ if layer_outputs[1] is not None:
1468
+ balance_loss += layer_outputs[1]
1469
+
1470
+ if use_cache:
1471
+ next_decoder_cache = layer_outputs[6 if output_attentions else 5]
1472
+
1473
+ if output_attentions:
1474
+ all_self_attns += (layer_outputs[5],)
1475
+
1476
+ num_dropped_tokens += (layer_outputs[2],)
1477
+ gate_load += (layer_outputs[3],)
1478
+ gate_importance += (layer_outputs[4],)
1479
+
1480
+ hidden_states = self.norm(hidden_states)
1481
+
1482
+ # add hidden states from the last decoder layer
1483
+ if output_hidden_states:
1484
+ all_hidden_states += (hidden_states,)
1485
+
1486
+ next_cache = None
1487
+ if use_cache:
1488
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1489
+ if not return_dict:
1490
+ return tuple(
1491
+ v
1492
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1493
+ if v is not None
1494
+ )
1495
+ return BaseMoEModelOutputWithPast(
1496
+ last_hidden_state=hidden_states,
1497
+ balance_loss=balance_loss,
1498
+ past_key_values=next_cache,
1499
+ hidden_states=all_hidden_states,
1500
+ attentions=all_self_attns,
1501
+ num_dropped_tokens=num_dropped_tokens,
1502
+ gate_load=gate_load,
1503
+ gate_importance=gate_importance,
1504
+ )
1505
+
1506
+ def reset_gate_network(self):
1507
+ for idx, decoder_layer in enumerate(self.layers):
1508
+ decoder_layer.reset_gate_network()
1509
+
1510
+ def reset_experts(self):
1511
+ for idx, decoder_layer in enumerate(self.layers):
1512
+ decoder_layer.reset_experts()
1513
+
1514
+
1515
+ class LlamaMoEForCausalLM(LlamaMoEPreTrainedModel):
1516
+ _tied_weights_keys = ["lm_head.weight"]
1517
+
1518
+ def __init__(self, config):
1519
+ super().__init__(config)
1520
+ self.model = LlamaMoEModel(config)
1521
+ self.pretraining_tp = config.pretraining_tp
1522
+ self.vocab_size = config.vocab_size
1523
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1524
+
1525
+ # Initialize weights and apply final processing
1526
+ self.post_init()
1527
+
1528
+ def get_input_embeddings(self):
1529
+ return self.model.embed_tokens
1530
+
1531
+ def set_input_embeddings(self, value):
1532
+ self.model.embed_tokens = value
1533
+
1534
+ def get_output_embeddings(self):
1535
+ return self.lm_head
1536
+
1537
+ def set_output_embeddings(self, new_embeddings):
1538
+ self.lm_head = new_embeddings
1539
+
1540
+ def set_decoder(self, decoder):
1541
+ self.model = decoder
1542
+
1543
+ def get_decoder(self):
1544
+ return self.model
1545
+
1546
+ def forward(
1547
+ self,
1548
+ input_ids=None,
1549
+ attention_mask=None,
1550
+ position_ids=None,
1551
+ past_key_values=None,
1552
+ inputs_embeds=None,
1553
+ labels=None,
1554
+ use_cache=None,
1555
+ output_attentions=None,
1556
+ output_hidden_states=None,
1557
+ return_dict=None,
1558
+ **kwargs,
1559
+ ):
1560
+ output_attentions = (
1561
+ output_attentions
1562
+ if output_attentions is not None
1563
+ else self.config.output_attentions
1564
+ )
1565
+ output_hidden_states = (
1566
+ output_hidden_states
1567
+ if output_hidden_states is not None
1568
+ else self.config.output_hidden_states
1569
+ )
1570
+ return_dict = (
1571
+ return_dict if return_dict is not None else self.config.use_return_dict
1572
+ )
1573
+
1574
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1575
+ outputs: BaseMoEModelOutputWithPast = self.model(
1576
+ input_ids=input_ids,
1577
+ attention_mask=attention_mask,
1578
+ position_ids=position_ids,
1579
+ past_key_values=past_key_values,
1580
+ inputs_embeds=inputs_embeds,
1581
+ use_cache=use_cache,
1582
+ output_attentions=output_attentions,
1583
+ output_hidden_states=output_hidden_states,
1584
+ return_dict=return_dict,
1585
+ )
1586
+
1587
+ hidden_states = outputs.last_hidden_state
1588
+ logits = self.lm_head(hidden_states)
1589
+
1590
+ loss = None
1591
+ if labels is not None:
1592
+ # Shift so that tokens < n predict n
1593
+ shift_logits = logits[..., :-1, :].contiguous()
1594
+ shift_labels = labels[..., 1:].contiguous()
1595
+ # Flatten the tokens
1596
+ loss_fct = nn.CrossEntropyLoss()
1597
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1598
+ shift_labels = shift_labels.view(-1)
1599
+ # Enable model parallelism
1600
+ shift_labels = shift_labels.to(shift_logits.device)
1601
+ loss = loss_fct(shift_logits, shift_labels)
1602
+ if outputs.balance_loss is not None and outputs.balance_loss > 0:
1603
+ loss += outputs.balance_loss
1604
+
1605
+ if not return_dict:
1606
+ output = (logits,) + outputs[1:]
1607
+ return (loss,) + output if loss is not None else output
1608
+
1609
+ return MoECausalLMOutputWithPast(
1610
+ loss=loss,
1611
+ logits=logits,
1612
+ past_key_values=outputs.past_key_values,
1613
+ hidden_states=outputs.hidden_states,
1614
+ attentions=outputs.attentions,
1615
+ num_dropped_tokens=outputs.num_dropped_tokens,
1616
+ balance_loss=outputs.balance_loss,
1617
+ gate_load=outputs.gate_load,
1618
+ gate_importance=outputs.gate_importance,
1619
+ )
1620
+
1621
+ def prepare_inputs_for_generation(
1622
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1623
+ ):
1624
+ if past_key_values is not None:
1625
+ if isinstance(past_key_values, Cache):
1626
+ cache_length = past_key_values.get_seq_length()
1627
+ past_length = past_key_values.seen_tokens
1628
+ max_cache_length = past_key_values.get_max_length()
1629
+ else:
1630
+ cache_length = past_length = past_key_values[0][0].shape[2]
1631
+ max_cache_length = None
1632
+
1633
+ # Keep only the unprocessed tokens:
1634
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1635
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1636
+ # input)
1637
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1638
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1639
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1640
+ # input_ids based on the past_length.
1641
+ elif past_length < input_ids.shape[1]:
1642
+ input_ids = input_ids[:, past_length:]
1643
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1644
+
1645
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1646
+ if (
1647
+ max_cache_length is not None
1648
+ and attention_mask is not None
1649
+ and cache_length + input_ids.shape[1] > max_cache_length
1650
+ ):
1651
+ attention_mask = attention_mask[:, -max_cache_length:]
1652
+
1653
+ position_ids = kwargs.get("position_ids", None)
1654
+ if attention_mask is not None and position_ids is None:
1655
+ # create position_ids on the fly for batch generation
1656
+ position_ids = attention_mask.long().cumsum(-1) - 1
1657
+ position_ids.masked_fill_(attention_mask == 0, 1)
1658
+ if past_key_values:
1659
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1660
+
1661
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1662
+ if inputs_embeds is not None and past_key_values is None:
1663
+ model_inputs = {"inputs_embeds": inputs_embeds}
1664
+ else:
1665
+ model_inputs = {"input_ids": input_ids}
1666
+
1667
+ model_inputs.update(
1668
+ {
1669
+ "position_ids": position_ids,
1670
+ "past_key_values": past_key_values,
1671
+ "use_cache": kwargs.get("use_cache"),
1672
+ "attention_mask": attention_mask,
1673
+ }
1674
+ )
1675
+ return model_inputs
1676
+
1677
+ @staticmethod
1678
+ def _reorder_cache(past_key_values, beam_idx):
1679
+ reordered_past = ()
1680
+ for layer_past in past_key_values:
1681
+ reordered_past += (
1682
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1683
+ )
1684
+ return reordered_past
1685
+
1686
+ def reset_gate_network(self):
1687
+ self.model.reset_gate_network()
1688
+
1689
+ def reset_experts(self):
1690
+ self.model.reset_experts()
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ }
29
+ },
30
+ "bos_token": "<s>",
31
+ "clean_up_tokenization_spaces": false,
32
+ "eos_token": "</s>",
33
+ "legacy": false,
34
+ "model_max_length": 2048,
35
+ "pad_token": "<unk>",
36
+ "padding_side": "right",
37
+ "sp_model_kwargs": {},
38
+ "spaces_between_special_tokens": false,
39
+ "tokenizer_class": "LlamaTokenizer",
40
+ "unk_token": "<unk>",
41
+ "use_default_system_prompt": false,
42
+ "use_fast": true
43
+ }