Spico commited on
Commit
b386a77
1 Parent(s): 6ce71d5

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. comment.txt +7 -0
  2. config.json +373 -0
  3. configuration_llama_moe.py +130 -0
  4. diff.patch +863 -0
  5. generation_config.json +7 -0
  6. model-00001-of-00003.safetensors +3 -0
  7. model-00002-of-00003.safetensors +3 -0
  8. model-00003-of-00003.safetensors +3 -0
  9. model.safetensors.index.json +0 -0
  10. modeling_llama_moe_hf.py +1690 -0
  11. sampling_info/100/load.pdf +0 -0
  12. sampling_info/100/prob_map.pdf +0 -0
  13. sampling_info/100/sim.pdf +0 -0
  14. sampling_info/1000/load.pdf +0 -0
  15. sampling_info/1000/prob_map.pdf +0 -0
  16. sampling_info/1000/sim.pdf +0 -0
  17. sampling_info/1100/load.pdf +0 -0
  18. sampling_info/1100/prob_map.pdf +0 -0
  19. sampling_info/1100/sim.pdf +0 -0
  20. sampling_info/1200/load.pdf +0 -0
  21. sampling_info/1200/prob_map.pdf +0 -0
  22. sampling_info/1200/sim.pdf +0 -0
  23. sampling_info/1300/load.pdf +0 -0
  24. sampling_info/1300/prob_map.pdf +0 -0
  25. sampling_info/1300/sim.pdf +0 -0
  26. sampling_info/1400/load.pdf +0 -0
  27. sampling_info/1400/prob_map.pdf +0 -0
  28. sampling_info/1400/sim.pdf +0 -0
  29. sampling_info/1500/load.pdf +0 -0
  30. sampling_info/1500/prob_map.pdf +0 -0
  31. sampling_info/1500/sim.pdf +0 -0
  32. sampling_info/1600/load.pdf +0 -0
  33. sampling_info/1600/prob_map.pdf +0 -0
  34. sampling_info/1600/sim.pdf +0 -0
  35. sampling_info/1700/load.pdf +0 -0
  36. sampling_info/1700/prob_map.pdf +0 -0
  37. sampling_info/1700/sim.pdf +0 -0
  38. sampling_info/1800/load.pdf +0 -0
  39. sampling_info/1800/prob_map.pdf +0 -0
  40. sampling_info/1800/sim.pdf +0 -0
  41. sampling_info/1900/load.pdf +0 -0
  42. sampling_info/1900/prob_map.pdf +0 -0
  43. sampling_info/1900/sim.pdf +0 -0
  44. sampling_info/200/load.pdf +0 -0
  45. sampling_info/200/prob_map.pdf +0 -0
  46. sampling_info/200/sim.pdf +0 -0
  47. sampling_info/2000/load.pdf +0 -0
  48. sampling_info/2000/prob_map.pdf +0 -0
  49. sampling_info/2000/sim.pdf +0 -0
  50. sampling_info/300/load.pdf +0 -0
comment.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Job ID: 2498282
2
+
3
+ Git commit: 10e3e0a update alpaca eval gen
4
+
5
+ Git branch: * main
6
+
7
+ Comment: llama_moe_four_mix_freeze_gate_100
config.json ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
3
+ "add_weight_norm": false,
4
+ "architectures": [
5
+ "LlamaMoEForCausalLM"
6
+ ],
7
+ "attention_bias": false,
8
+ "attention_dropout": 0.0,
9
+ "auto_map": {
10
+ "AutoConfig": "configuration_llama_moe.LlamaMoEConfig",
11
+ "AutoModel": "modeling_llama_moe_hf.LlamaMoEModel",
12
+ "AutoModelForCausalLM": "modeling_llama_moe_hf.LlamaMoEForCausalLM"
13
+ },
14
+ "bos_token_id": 1,
15
+ "calculator_type": "UniversalCalculator",
16
+ "capacity_factor": 1.25,
17
+ "drop_tokens": true,
18
+ "dropped_padding": "zero",
19
+ "eos_token_id": 2,
20
+ "gate_add_noise": true,
21
+ "gate_balance_loss_weight": 0.01,
22
+ "gate_network": "mlp",
23
+ "gate_noise_epsilon": 0.01,
24
+ "gate_type": "TopKBalancedNoisyGate",
25
+ "gate_use_balance": true,
26
+ "gate_use_softmax": true,
27
+ "gates": "mlp",
28
+ "hidden_act": "silu",
29
+ "hidden_size": 4096,
30
+ "initializer_range": 0.02,
31
+ "intermediate_size": 11008,
32
+ "max_position_embeddings": 4096,
33
+ "model_type": "llama_moe",
34
+ "multiply_gate_scores": true,
35
+ "num_attention_heads": 32,
36
+ "num_experts": 8,
37
+ "num_hidden_layers": 32,
38
+ "num_key_value_heads": 32,
39
+ "num_selects": 2,
40
+ "pad_token_id": 0,
41
+ "pretraining_tp": 1,
42
+ "rms_norm_eps": 1e-05,
43
+ "rope_scaling": null,
44
+ "rope_theta": 10000.0,
45
+ "score_scale_factor": 4.0,
46
+ "size_experts": [
47
+ [
48
+ 1376,
49
+ 1376,
50
+ 1376,
51
+ 1376,
52
+ 1376,
53
+ 1376,
54
+ 1376,
55
+ 1376
56
+ ],
57
+ [
58
+ 1376,
59
+ 1376,
60
+ 1376,
61
+ 1376,
62
+ 1376,
63
+ 1376,
64
+ 1376,
65
+ 1376
66
+ ],
67
+ [
68
+ 1376,
69
+ 1376,
70
+ 1376,
71
+ 1376,
72
+ 1376,
73
+ 1376,
74
+ 1376,
75
+ 1376
76
+ ],
77
+ [
78
+ 1376,
79
+ 1376,
80
+ 1376,
81
+ 1376,
82
+ 1376,
83
+ 1376,
84
+ 1376,
85
+ 1376
86
+ ],
87
+ [
88
+ 1376,
89
+ 1376,
90
+ 1376,
91
+ 1376,
92
+ 1376,
93
+ 1376,
94
+ 1376,
95
+ 1376
96
+ ],
97
+ [
98
+ 1376,
99
+ 1376,
100
+ 1376,
101
+ 1376,
102
+ 1376,
103
+ 1376,
104
+ 1376,
105
+ 1376
106
+ ],
107
+ [
108
+ 1376,
109
+ 1376,
110
+ 1376,
111
+ 1376,
112
+ 1376,
113
+ 1376,
114
+ 1376,
115
+ 1376
116
+ ],
117
+ [
118
+ 1376,
119
+ 1376,
120
+ 1376,
121
+ 1376,
122
+ 1376,
123
+ 1376,
124
+ 1376,
125
+ 1376
126
+ ],
127
+ [
128
+ 1376,
129
+ 1376,
130
+ 1376,
131
+ 1376,
132
+ 1376,
133
+ 1376,
134
+ 1376,
135
+ 1376
136
+ ],
137
+ [
138
+ 1376,
139
+ 1376,
140
+ 1376,
141
+ 1376,
142
+ 1376,
143
+ 1376,
144
+ 1376,
145
+ 1376
146
+ ],
147
+ [
148
+ 1376,
149
+ 1376,
150
+ 1376,
151
+ 1376,
152
+ 1376,
153
+ 1376,
154
+ 1376,
155
+ 1376
156
+ ],
157
+ [
158
+ 1376,
159
+ 1376,
160
+ 1376,
161
+ 1376,
162
+ 1376,
163
+ 1376,
164
+ 1376,
165
+ 1376
166
+ ],
167
+ [
168
+ 1376,
169
+ 1376,
170
+ 1376,
171
+ 1376,
172
+ 1376,
173
+ 1376,
174
+ 1376,
175
+ 1376
176
+ ],
177
+ [
178
+ 1376,
179
+ 1376,
180
+ 1376,
181
+ 1376,
182
+ 1376,
183
+ 1376,
184
+ 1376,
185
+ 1376
186
+ ],
187
+ [
188
+ 1376,
189
+ 1376,
190
+ 1376,
191
+ 1376,
192
+ 1376,
193
+ 1376,
194
+ 1376,
195
+ 1376
196
+ ],
197
+ [
198
+ 1376,
199
+ 1376,
200
+ 1376,
201
+ 1376,
202
+ 1376,
203
+ 1376,
204
+ 1376,
205
+ 1376
206
+ ],
207
+ [
208
+ 1376,
209
+ 1376,
210
+ 1376,
211
+ 1376,
212
+ 1376,
213
+ 1376,
214
+ 1376,
215
+ 1376
216
+ ],
217
+ [
218
+ 1376,
219
+ 1376,
220
+ 1376,
221
+ 1376,
222
+ 1376,
223
+ 1376,
224
+ 1376,
225
+ 1376
226
+ ],
227
+ [
228
+ 1376,
229
+ 1376,
230
+ 1376,
231
+ 1376,
232
+ 1376,
233
+ 1376,
234
+ 1376,
235
+ 1376
236
+ ],
237
+ [
238
+ 1376,
239
+ 1376,
240
+ 1376,
241
+ 1376,
242
+ 1376,
243
+ 1376,
244
+ 1376,
245
+ 1376
246
+ ],
247
+ [
248
+ 1376,
249
+ 1376,
250
+ 1376,
251
+ 1376,
252
+ 1376,
253
+ 1376,
254
+ 1376,
255
+ 1376
256
+ ],
257
+ [
258
+ 1376,
259
+ 1376,
260
+ 1376,
261
+ 1376,
262
+ 1376,
263
+ 1376,
264
+ 1376,
265
+ 1376
266
+ ],
267
+ [
268
+ 1376,
269
+ 1376,
270
+ 1376,
271
+ 1376,
272
+ 1376,
273
+ 1376,
274
+ 1376,
275
+ 1376
276
+ ],
277
+ [
278
+ 1376,
279
+ 1376,
280
+ 1376,
281
+ 1376,
282
+ 1376,
283
+ 1376,
284
+ 1376,
285
+ 1376
286
+ ],
287
+ [
288
+ 1376,
289
+ 1376,
290
+ 1376,
291
+ 1376,
292
+ 1376,
293
+ 1376,
294
+ 1376,
295
+ 1376
296
+ ],
297
+ [
298
+ 1376,
299
+ 1376,
300
+ 1376,
301
+ 1376,
302
+ 1376,
303
+ 1376,
304
+ 1376,
305
+ 1376
306
+ ],
307
+ [
308
+ 1376,
309
+ 1376,
310
+ 1376,
311
+ 1376,
312
+ 1376,
313
+ 1376,
314
+ 1376,
315
+ 1376
316
+ ],
317
+ [
318
+ 1376,
319
+ 1376,
320
+ 1376,
321
+ 1376,
322
+ 1376,
323
+ 1376,
324
+ 1376,
325
+ 1376
326
+ ],
327
+ [
328
+ 1376,
329
+ 1376,
330
+ 1376,
331
+ 1376,
332
+ 1376,
333
+ 1376,
334
+ 1376,
335
+ 1376
336
+ ],
337
+ [
338
+ 1376,
339
+ 1376,
340
+ 1376,
341
+ 1376,
342
+ 1376,
343
+ 1376,
344
+ 1376,
345
+ 1376
346
+ ],
347
+ [
348
+ 1376,
349
+ 1376,
350
+ 1376,
351
+ 1376,
352
+ 1376,
353
+ 1376,
354
+ 1376,
355
+ 1376
356
+ ],
357
+ [
358
+ 1376,
359
+ 1376,
360
+ 1376,
361
+ 1376,
362
+ 1376,
363
+ 1376,
364
+ 1376,
365
+ 1376
366
+ ]
367
+ ],
368
+ "tie_word_embeddings": false,
369
+ "torch_dtype": "bfloat16",
370
+ "transformers_version": "4.36.2",
371
+ "use_cache": true,
372
+ "vocab_size": 32000
373
+ }
configuration_llama_moe.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class LlamaMoEConfig(PretrainedConfig):
5
+ model_type = "llama_moe"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ intermediate_size=11008,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ num_key_value_heads=None,
16
+ hidden_act="silu",
17
+ max_position_embeddings=2048,
18
+ initializer_range=0.02,
19
+ rms_norm_eps=1e-6,
20
+ use_cache=True,
21
+ pad_token_id=0,
22
+ bos_token_id=1,
23
+ eos_token_id=2,
24
+ pretraining_tp=1,
25
+ tie_word_embeddings=False,
26
+ rope_theta=10000.0,
27
+ rope_scaling=None,
28
+ attention_bias=False,
29
+ attention_dropout=0.0,
30
+ # -------- moe expert configs --------
31
+ num_experts=16,
32
+ num_selects=4,
33
+ size_experts=None,
34
+ # -------- moe gate configs --------
35
+ gate_type="TopKBalancedNoisyGate",
36
+ gate_network="mlp",
37
+ gate_use_softmax=True,
38
+ gate_use_balance=True,
39
+ gate_balance_loss_weight=1e-2,
40
+ gate_add_noise=True,
41
+ # TopKBalancedNoisyGate
42
+ gate_noise_epsilon=1e-2,
43
+ # -------- moe calculator configs --------
44
+ calculator_type="UniversalCalculator",
45
+ multiply_gate_scores=True,
46
+ score_scale_factor=1.0,
47
+ add_weight_norm=False,
48
+ # SwitchDropTokenCalculator
49
+ drop_tokens=True,
50
+ dropped_padding="zero",
51
+ capacity_factor=1.25,
52
+ **kwargs,
53
+ ):
54
+ self.vocab_size = vocab_size
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.hidden_size = hidden_size
57
+ self.intermediate_size = intermediate_size
58
+ self.num_hidden_layers = num_hidden_layers
59
+ self.num_attention_heads = num_attention_heads
60
+ self.hidden_act = hidden_act
61
+ self.initializer_range = initializer_range
62
+ self.rms_norm_eps = rms_norm_eps
63
+ self.pretraining_tp = pretraining_tp
64
+ self.use_cache = use_cache
65
+ self.rope_theta = rope_theta
66
+ self.rope_scaling = rope_scaling
67
+ self._rope_scaling_validation()
68
+ self.attention_bias = attention_bias
69
+ self.attention_dropout = attention_dropout
70
+
71
+ self.num_experts = num_experts
72
+ self.num_selects = num_selects
73
+ self.size_experts = size_experts
74
+
75
+ self.gate_type = gate_type
76
+ self.gate_network = gate_network
77
+ self.gate_use_softmax = gate_use_softmax
78
+ self.gate_use_balance = gate_use_balance
79
+ self.gate_balance_loss_weight = gate_balance_loss_weight
80
+ self.gate_add_noise = gate_add_noise
81
+ self.gate_noise_epsilon = gate_noise_epsilon
82
+
83
+ self.calculator_type = calculator_type
84
+ self.multiply_gate_scores = multiply_gate_scores
85
+ self.score_scale_factor = score_scale_factor
86
+ self.add_weight_norm = add_weight_norm
87
+ self.drop_tokens = drop_tokens
88
+ self.dropped_padding = dropped_padding
89
+ self.capacity_factor = capacity_factor
90
+
91
+ # for backward compatibility
92
+ if num_key_value_heads is None:
93
+ num_key_value_heads = num_attention_heads
94
+
95
+ self.num_key_value_heads = num_key_value_heads
96
+
97
+ super().__init__(
98
+ pad_token_id=pad_token_id,
99
+ bos_token_id=bos_token_id,
100
+ eos_token_id=eos_token_id,
101
+ tie_word_embeddings=tie_word_embeddings,
102
+ **kwargs,
103
+ )
104
+
105
+ def _rope_scaling_validation(self):
106
+ """
107
+ Validate the `rope_scaling` configuration.
108
+ """
109
+ if self.rope_scaling is None:
110
+ return
111
+
112
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
113
+ raise ValueError(
114
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
115
+ f"got {self.rope_scaling}"
116
+ )
117
+ rope_scaling_type = self.rope_scaling.get("type", None)
118
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
119
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
120
+ raise ValueError(
121
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
122
+ )
123
+ if (
124
+ rope_scaling_factor is None
125
+ or not isinstance(rope_scaling_factor, float)
126
+ or rope_scaling_factor <= 1.0
127
+ ):
128
+ raise ValueError(
129
+ f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}"
130
+ )
diff.patch ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/.gitignore b/.gitignore
2
+ index c243024..8c28ce3 100644
3
+ --- a/.gitignore
4
+ +++ b/.gitignore
5
+ @@ -175,6 +175,7 @@ debug.py
6
+ wandb/
7
+ nohup.out
8
+ lm-evaluation-harness/
9
+ +bigcode-evaluation-harness/
10
+ results/**/*.json
11
+ results/**/*.jsonl
12
+ results/**/*.db
13
+ diff --git a/README.md b/README.md
14
+ index 8813a32..b276a78 100644
15
+ --- a/README.md
16
+ +++ b/README.md
17
+ @@ -26,6 +26,11 @@ bash scripts/data.sh
18
+ git clone https://github.com/EleutherAI/lm-evaluation-harness.git
19
+ cd lm-evaluation-harness
20
+ pip install -e .
21
+ +# commit: 9cfa52b
22
+ +git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git
23
+ +cd bigcode-evaluation-harness
24
+ +# change `pyext==0.5` in `bigcode-evaluation-harness/requirements.txt`, ref: https://github.com/bigcode-project/bigcode-evaluation-harness/pull/181
25
+ +pip install -e .
26
+ ```
27
+
28
+ ## 📃 TODO
29
+ diff --git a/scripts/eval.sh b/scripts/eval.sh
30
+ deleted file mode 100644
31
+ index 4f41b37..0000000
32
+ --- a/scripts/eval.sh
33
+ +++ /dev/null
34
+ @@ -1,96 +0,0 @@
35
+ -# nohup srun -p MoE --gres gpu:1 bash scripts/eval.sh all /mnt/petrelfs/share_data/quxiaoye/models/Sheared-LLaMA-2.7B True results/Sheared-LLaMA-2.7B 1>logs/eval-all-Sheared-LLaMA-2.7B.log 2>&1 &
36
+ -
37
+ -mmlu() {
38
+ - # MMLU: https://github.com/princeton-nlp/LLM-Shearing/blob/20ebd2645a8ff5fa65874e1347f9891b80e01805/icl_eval/run_eval.sh#L18
39
+ - MODEL=$1
40
+ - TRUST_REMOTE_CODE=$2
41
+ - RESULT_DIR=$3
42
+ - mkdir -p $RESULT_DIR
43
+ -
44
+ - lm_eval \
45
+ - --model hf \
46
+ - --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
47
+ - --tasks mmlu_computer_security,mmlu_high_school_chemistry,mmlu_philosophy,mmlu_elementary_mathematics,mmlu_prehistory,mmlu_formal_logic,mmlu_high_school_mathematics,mmlu_econometrics,mmlu_moral_scenarios,mmlu_college_mathematics,mmlu_high_school_government_and_politics,mmlu_us_foreign_policy,mmlu_high_school_world_history,mmlu_conceptual_physics,mmlu_college_medicine,mmlu_international_law,mmlu_abstract_algebra,mmlu_logical_fallacies,mmlu_machine_learning,mmlu_medical_genetics,mmlu_public_relations,mmlu_college_biology,mmlu_marketing,mmlu_electrical_engineering,mmlu_anatomy,mmlu_high_school_us_history,mmlu_high_school_biology,mmlu_miscellaneous,mmlu_high_school_psychology,mmlu_sociology,mmlu_business_ethics,mmlu_high_school_geography,mmlu_human_aging,mmlu_high_school_statistics,mmlu_moral_disputes,mmlu_professional_psychology,mmlu_global_facts,mmlu_college_physics,mmlu_nutrition,mmlu_high_school_macroeconomics,mmlu_world_religions,mmlu_professional_medicine,mmlu_high_school_computer_science,mmlu_college_chemistry,mmlu_human_sexuality,mmlu_high_school_microeconomics,mmlu_astronomy,mmlu_professional_accounting,mmlu_high_school_european_history,mmlu_jurisprudence,mmlu_professional_law,mmlu_high_school_physics,mmlu_virology,mmlu_management,mmlu_college_computer_science,mmlu_clinical_knowledge,mmlu_security_studies \
48
+ - --num_fewshot 5 \
49
+ - --device cuda:0 \
50
+ - --batch_size auto \
51
+ - --verbosity DEBUG \
52
+ - --output_path $RESULT_DIR/mmlu.json
53
+ -}
54
+ -
55
+ -bbh() {
56
+ - # Big Bench Hard (BBH): https://arxiv.org/pdf/2210.09261.pdf
57
+ - MODEL=$1
58
+ - TRUST_REMOTE_CODE=$2
59
+ - RESULT_DIR=$3
60
+ - mkdir -p $RESULT_DIR
61
+ -
62
+ - lm_eval \
63
+ - --log_samples \
64
+ - --model hf \
65
+ - --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
66
+ - --tasks bbh_fewshot_boolean_expressions,bbh_fewshot_causal_judgement,bbh_fewshot_date_understanding,bbh_fewshot_disambiguation_qa,bbh_fewshot_dyck_languages,bbh_fewshot_formal_fallacies,bbh_fewshot_geometric_shapes,bbh_fewshot_hyperbaton,bbh_fewshot_logical_deduction_five_objects,bbh_fewshot_logical_deduction_seven_objects,bbh_fewshot_logical_deduction_three_objects,bbh_fewshot_movie_recommendation,bbh_fewshot_multistep_arithmetic_two,bbh_fewshot_navigate,bbh_fewshot_object_counting,bbh_fewshot_penguins_in_a_table,bbh_fewshot_reasoning_about_colored_objects,bbh_fewshot_ruin_names,bbh_fewshot_salient_translation_error_detection,bbh_fewshot_snarks,bbh_fewshot_sports_understanding,bbh_fewshot_temporal_sequences,bbh_fewshot_tracking_shuffled_objects_five_objects,bbh_fewshot_tracking_shuffled_objects_seven_objects,bbh_fewshot_tracking_shuffled_objects_three_objects,bbh_fewshot_web_of_lies,bbh_fewshot_word_sorting \
67
+ - --device cuda:0 \
68
+ - --batch_size auto \
69
+ - --verbosity DEBUG \
70
+ - --output_path $RESULT_DIR/bbh.json
71
+ -}
72
+ -
73
+ -reasoning() {
74
+ - MODEL=$1
75
+ - TRUST_REMOTE_CODE=$2
76
+ - RESULT_DIR=$3
77
+ - mkdir -p $RESULT_DIR
78
+ -
79
+ - lm_eval \
80
+ - --log_samples \
81
+ - --model hf \
82
+ - --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
83
+ - --tasks gsm8k_cot \
84
+ - --device cuda:0 \
85
+ - --batch_size auto \
86
+ - --verbosity DEBUG \
87
+ - --output_path $RESULT_DIR/reasoning.json
88
+ -}
89
+ -
90
+ -qa() {
91
+ - MODEL=$1
92
+ - TRUST_REMOTE_CODE=$2
93
+ - RESULT_DIR=$3
94
+ - mkdir -p $RESULT_DIR
95
+ -
96
+ - lm_eval \
97
+ - --log_samples \
98
+ - --model hf \
99
+ - --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
100
+ - --tasks arc_easy,arc_challenge,boolq \
101
+ - --num_fewshot 0 \
102
+ - --device cuda:0 \
103
+ - --batch_size auto \
104
+ - --verbosity DEBUG \
105
+ - --output_path $RESULT_DIR/qa.json
106
+ -}
107
+ -
108
+ -EVAL_TASK=$1
109
+ -shift 1
110
+ -start=$(date +%s)
111
+ -case $EVAL_TASK in
112
+ - mmlu)
113
+ - mmlu $* ;;
114
+ - bbh)
115
+ - bbh $* ;;
116
+ - reasoning)
117
+ - reasoning $* ;;
118
+ - qa)
119
+ - qa $* ;;
120
+ - all)
121
+ - mmlu $*
122
+ - bbh $*
123
+ - reasoning $*
124
+ - qa $*
125
+ - ;;
126
+ - *)
127
+ - echo "$EVAL_TASK not recognized!";;
128
+ -esac
129
+ -end=$(date +%s)
130
+ -echo "Elapsed Time: $(($end-$start)) seconds"
131
+ diff --git a/scripts/four_mix/freeze_gate.sh b/scripts/four_mix/freeze_gate.sh
132
+ index d94d78c..70afb8e 100644
133
+ --- a/scripts/four_mix/freeze_gate.sh
134
+ +++ b/scripts/four_mix/freeze_gate.sh
135
+ @@ -83,8 +83,11 @@ num_gpus=4
136
+
137
+ python -m src.eval.gen_mt_ans \
138
+ --model-path $output_dir \
139
+ - --model-id $task_name \
140
+ - --num-gpus-total $num_gpus
141
+ + --model-id $task_name
142
+ +
143
+ + python -m src.eval.gen_alpaca_eval_ans \
144
+ + --model-path $output_dir \
145
+ + --model-id $task_name
146
+ }
147
+
148
+ # nohup srun -p MoE --ntasks-per-node=1 --cpus-per-task=16 --mem=128G --nodes=1 --gres=gpu:4 bash "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/scripts/one_data_steps_dynamic.sh" "llama_moe_orca_epochs_cluster_4" "auto" "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new" "data/open_orca_clustered/4" "data/open_orca_clustered_eval/4" 1>logs/llama_moe_orca_cluster_4_dynamic.log 2>&1 &
149
+ diff --git a/scripts/gen_mt_bench_ans.sh b/scripts/gen_mt_bench_ans.sh
150
+ deleted file mode 100644
151
+ index f251644..0000000
152
+ --- a/scripts/gen_mt_bench_ans.sh
153
+ +++ /dev/null
154
+ @@ -1,32 +0,0 @@
155
+ -#!/usr/bin/bash
156
+ -
157
+ -#SBATCH --job-name=moe_gen
158
+ -#SBATCH --output=logs/%x-%j.log
159
+ -#SBATCH --error=logs/%x-%j.log
160
+ -
161
+ -#SBATCH --partition=MoE
162
+ -#SBATCH --ntasks-per-node=1
163
+ -#SBATCH --cpus-per-task=16
164
+ -#SBATCH --mem=64G
165
+ -
166
+ -#SBATCH --nodes=1
167
+ -#SBATCH --gres=gpu:1
168
+ -#SBATCH --quotatype=auto
169
+ -
170
+ -{
171
+ - # python -m fastchat.llm_judge.gen_model_answer \
172
+ - # --model-path outputs/sheared_llama_sharegpt/moe_sft-2411306 \
173
+ - # --model-id sheared_llama_sharegpt
174
+ -
175
+ - # python -m fastchat.llm_judge.gen_model_answer \
176
+ - # --model-path outputs/sheared_llama_uniform_mix/moe_sft-2421072 \
177
+ - # --model-id sheared_llama_uniform_mix
178
+ -
179
+ - bash scripts/cp_model_files.sh outputs/llama_moe/moe_sft-2409782
180
+ - python -m fastchat.llm_judge.gen_model_answer \
181
+ - --model-path outputs/llama_moe/moe_sft-2409782 \
182
+ - --model-id llama_moe_uniform_mix
183
+ -}
184
+ -
185
+ -# nohup srun -p MoE -n1 -N1 --gres=gpu:1 --quotatype spot python -m fastchat.llm_judge.gen_model_answer --model-path outputs/sheared_llama_sharegpt/moe_sft-2411306 --model-id sheared_llama_sharegpt 1>logs/mt_bench_gen_sheared_llama_sharegpt.log 2>&1 &
186
+ -# nohup srun -p MoE -n1 -N1 --gres=gpu:1 --quotatype spot python -m fastchat.llm_judge.gen_model_answer --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/llama_moe_sharegpt/moe_sft-2411309 --model-id llama_moe_sharegpt 1>logs/mt_bench_gen_llama_moe_sharegpt.log 2>&1 &
187
+ diff --git a/scripts/multi.sh b/scripts/multi.sh
188
+ index bcd83b8..e399761 100644
189
+ --- a/scripts/multi.sh
190
+ +++ b/scripts/multi.sh
191
+ @@ -100,5 +100,8 @@ nohup srun -p MoE --ntasks-per-node=1 --cpus-per-task=16 --mem=128G --nodes=1 --
192
+ nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_mt_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/llama_moe_four_mix_uniform/bash-2485396 --model-id llama_moe_four_mix_uniform 1>logs/gen_mt_ans-llama_moe_four_mix_uniform.log 2>&1 &
193
+ nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_mt_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/sheared_four_mix_uniform/bash-2485397 --model-id sheared_four_mix_uniform 1>logs/gen_mt_ans-sheared_four_mix_uniform.log 2>&1 &
194
+
195
+ -nohup srun -p MoE --gres gpu:1 python -m src.eval.get_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/llama_moe_four_mix_uniform/bash-2485396 --model-id llama_moe_four_mix_uniform 1>logs/gen_alpaca_eval-llama_moe_four_mix_uniform.log 2>&1 &
196
+ -nohup srun -p MoE --gres gpu:1 python -m src.eval.get_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/sheared_four_mix_uniform/bash-2485397 --model-id sheared_four_mix_uniform 1>logs/gen_alpaca_eval-sheared_four_mix_uniform.log 2>&1 &
197
+ +nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/llama_moe_four_mix_uniform/bash-2485396 --model-id llama_moe_four_mix_uniform 1>logs/gen_alpaca_eval-llama_moe_four_mix_uniform.log 2>&1 &
198
+ +nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/sheared_four_mix_uniform/bash-2485397 --model-id sheared_four_mix_uniform 1>logs/gen_alpaca_eval-sheared_four_mix_uniform.log 2>&1 &
199
+ +
200
+ +nohup srun -p MoE --gres gpu:1 bash scripts/eval/eval.sh reasoning /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_wo_gate_noise/moe_sft-2492650 True results/llama_moe_four_mix_wo_pad_wo_gate_noise 1>logs/eval-reasoning-llama_moe_four_mix_wo_pad_wo_gate_noise.log 2>&1 &
201
+ +nohup srun -p MoE --gres gpu:1 bash scripts/eval/eval.sh reasoning /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad/moe_sft-2491633 True results/llama_moe_four_mix_wo_pad 1>logs/eval-reasoning-llama_moe_four_mix_wo_pad.log 2>&1 &
202
+ diff --git a/src/callbacks.py b/src/callbacks.py
203
+ index a750f69..e9d0c04 100644
204
+ --- a/src/callbacks.py
205
+ +++ b/src/callbacks.py
206
+ @@ -6,6 +6,7 @@ import torch
207
+ import numpy as np
208
+ from loguru import logger
209
+ from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
210
+ +from transformers.utils import is_flash_attn_2_available
211
+
212
+ from src.utils.config import TrainingArguments
213
+ from src.utils.io import append_jsonlines
214
+ @@ -22,6 +23,7 @@ class AdaptiveSamplingCallback(TrainerCallback):
215
+ criterion: Optional[Literal["min", "max", "mean"]] = "mean",
216
+ sim_type: Optional[Literal["cos", "l2"]] = "cos",
217
+ ):
218
+ + assert is_flash_attn_2_available(), "Make sure you have flash-attn installed"
219
+ self.criterion = criterion
220
+ self.sim_type = sim_type
221
+ self.prob_map = {}
222
+ @@ -74,8 +76,8 @@ class AdaptiveSamplingCallback(TrainerCallback):
223
+ cls,
224
+ ori_weights: np.ndarray,
225
+ delta: np.ndarray,
226
+ - eta: float = 1.0,
227
+ - c: float = 1e-4,
228
+ + eta: float = 10.0,
229
+ + c: float = 5e-2,
230
+ ) -> np.ndarray:
231
+ def _softmax(vec: np.ndarray) -> np.ndarray:
232
+ exps = np.exp(vec - np.max(vec))
233
+ diff --git a/src/core/train.py b/src/core/train.py
234
+ index 2be5558..9b1f694 100644
235
+ --- a/src/core/train.py
236
+ +++ b/src/core/train.py
237
+ @@ -7,13 +7,12 @@ from loguru import logger
238
+ from src.utils.config import ModelArguments, DataArguments, TrainingArguments
239
+ from src.data import (
240
+ SubDirWeightedPackedJsonlDataset,
241
+ - get_uniform_sampling_ratio,
242
+ fault_tolerance_data_collator,
243
+ CachedJsonlDataset,
244
+ get_cached_datasets_from_dir,
245
+ )
246
+ from src.utils.io import trainer_save_model_safe
247
+ -from src.models import LlamaMoEForCausalLM, LlamaMoEConfig
248
+ +from src.models import LlamaMoEForCausalLM, LlamaMoEConfig, DeepseekConfig, DeepseekForCausalLM
249
+ from src.trainer import GateLoadRecordingTrainer
250
+ from src.callbacks import AdaptiveSamplingCallback
251
+
252
+ @@ -36,6 +35,9 @@ def get_model_and_tokenizer(
253
+ elif model_type == "llama_moe":
254
+ ConfigClass = LlamaMoEConfig
255
+ ModelClass = LlamaMoEForCausalLM
256
+ + elif model_type == "deepseek":
257
+ + ConfigClass = DeepseekConfig
258
+ + ModelClass = DeepseekForCausalLM
259
+ else:
260
+ raise ValueError(f"Unknown model type: {model_type}")
261
+
262
+ @@ -54,6 +56,21 @@ def get_model_and_tokenizer(
263
+ config.update(additional_config)
264
+ logger.info("Config ready")
265
+
266
+ + tokenizer = transformers.AutoTokenizer.from_pretrained(
267
+ + model_name_or_path,
268
+ + cache_dir=cache_dir,
269
+ + model_max_length=model_max_length,
270
+ + padding_side=padding_side,
271
+ + use_fast=False,
272
+ + trust_remote_code=trust_remote_code,
273
+ + )
274
+ + if tokenizer.pad_token is None:
275
+ + if tokenizer.unk_token is not None:
276
+ + tokenizer.pad_token = tokenizer.unk_token
277
+ + else:
278
+ + tokenizer.pad_token = tokenizer.eos_token
279
+ + logger.info(f"tokenizer ready, pad_token: {tokenizer.pad_token}")
280
+ +
281
+ # Load model and tokenizer
282
+ model = ModelClass.from_pretrained(
283
+ model_name_or_path,
284
+ @@ -65,18 +82,6 @@ def get_model_and_tokenizer(
285
+ )
286
+ logger.info("model ready")
287
+
288
+ - tokenizer = transformers.AutoTokenizer.from_pretrained(
289
+ - model_name_or_path,
290
+ - cache_dir=cache_dir,
291
+ - model_max_length=model_max_length,
292
+ - padding_side=padding_side,
293
+ - use_fast=False,
294
+ - trust_remote_code=trust_remote_code,
295
+ - )
296
+ - if tokenizer.pad_token != tokenizer.unk_token:
297
+ - tokenizer.pad_token = tokenizer.unk_token
298
+ - logger.info("tokenizer ready")
299
+ -
300
+ return model, tokenizer
301
+
302
+
303
+ @@ -117,7 +122,9 @@ def train():
304
+ train_dataset = SubDirWeightedPackedJsonlDataset(
305
+ data_args.dataset_dir_or_path,
306
+ tokenizer,
307
+ - prob_map=get_uniform_sampling_ratio(data_args.dataset_dir_or_path),
308
+ + # prob_map=get_uniform_sampling_ratio(data_args.dataset_dir_or_path),
309
+ + # prob_map={"code": 0.25119094959816823, "math": 0.2674581878910902, "orca": 0.243050776175138, "sharegpt": 0.23830008633560357},
310
+ + prob_map=data_args.prob_map,
311
+ seed=training_args.seed,
312
+ )
313
+ elif datapath.is_file():
314
+ diff --git a/src/data.py b/src/data.py
315
+ index d783a21..a1a8ff7 100644
316
+ --- a/src/data.py
317
+ +++ b/src/data.py
318
+ @@ -20,6 +20,7 @@ def preprocess(
319
+ instances,
320
+ tokenizer: transformers.PreTrainedTokenizer,
321
+ ) -> Dict:
322
+ + tokenizer_legacy = getattr(tokenizer, "legacy", None)
323
+ conv = Conversation()
324
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
325
+
326
+ @@ -72,7 +73,7 @@ def preprocess(
327
+ # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
328
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
329
+
330
+ - if i != 0 and not tokenizer.legacy:
331
+ + if i != 0 and not tokenizer_legacy:
332
+ # The legacy and non-legacy modes handle special tokens differently
333
+ instruction_len -= 1
334
+
335
+ @@ -80,7 +81,7 @@ def preprocess(
336
+ target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
337
+ cur_len += turn_len
338
+
339
+ - if i != 0 and not tokenizer.legacy:
340
+ + if i != 0 and not tokenizer_legacy:
341
+ # The legacy and non-legacy modes handle special tokens differently
342
+ cur_len -= 1
343
+
344
+ diff --git a/src/eval/get_alpaca_eval_ans.py b/src/eval/get_alpaca_eval_ans.py
345
+ deleted file mode 100644
346
+ index 1ff3e5e..0000000
347
+ --- a/src/eval/get_alpaca_eval_ans.py
348
+ +++ /dev/null
349
+ @@ -1,113 +0,0 @@
350
+ -import argparse
351
+ -from pathlib import Path
352
+ -
353
+ -import torch
354
+ -import datasets
355
+ -from tqdm import tqdm
356
+ -
357
+ -from src.core.train import get_model_and_tokenizer
358
+ -from src.utils.conversation import Conversation
359
+ -from src.utils.io import dump_json
360
+ -
361
+ -
362
+ -@torch.inference_mode()
363
+ -def run_eval(model_path, model_id, max_new_tokens):
364
+ - model, tokenizer = get_model_and_tokenizer(
365
+ - "auto",
366
+ - model_path,
367
+ - torch_dtype=torch.bfloat16,
368
+ - trust_remote_code=True,
369
+ - )
370
+ - model.cuda()
371
+ - model.eval()
372
+ -
373
+ - conv = Conversation()
374
+ - outputs = []
375
+ - eval_set = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
376
+ - for example in tqdm(eval_set, desc="Eval"):
377
+ - conv.append_message(conv.roles[0], example["instruction"])
378
+ - conv.append_message(conv.roles[1], None)
379
+ - prompt = conv.get_prompt()
380
+ - input_ids = tokenizer([prompt], return_tensors="pt").input_ids
381
+ - conv.clear_msg()
382
+ - # generate here is a placeholder for your models generations
383
+ - output_ids = model.generate(
384
+ - input_ids.cuda(),
385
+ - do_sample=False,
386
+ - temperature=0.0,
387
+ - max_new_tokens=max_new_tokens,
388
+ - )
389
+ - if model.config.is_encoder_decoder:
390
+ - output_ids = output_ids[0]
391
+ - else:
392
+ - output_ids = output_ids[0][len(input_ids[0]) :] # noqa: E203
393
+ - # be consistent with the template's stop_token_ids
394
+ - if conv.stop_token_ids:
395
+ - stop_token_ids_index = [
396
+ - i
397
+ - for i, id in enumerate(output_ids)
398
+ - if id in conv.stop_token_ids
399
+ - ]
400
+ - if len(stop_token_ids_index) > 0:
401
+ - output_ids = output_ids[: stop_token_ids_index[0]]
402
+ -
403
+ - output = tokenizer.decode(
404
+ - output_ids,
405
+ - spaces_between_special_tokens=False,
406
+ - )
407
+ - if conv.stop_str and isinstance(conv.stop_str, list):
408
+ - stop_str_indices = sorted(
409
+ - [
410
+ - output.find(stop_str)
411
+ - for stop_str in conv.stop_str
412
+ - if output.find(stop_str) > 0
413
+ - ]
414
+ - )
415
+ - if len(stop_str_indices) > 0:
416
+ - output = output[: stop_str_indices[0]]
417
+ - elif conv.stop_str and output.find(conv.stop_str) > 0:
418
+ - output = output[: output.find(conv.stop_str)]
419
+ -
420
+ - for special_token in tokenizer.special_tokens_map.values():
421
+ - if isinstance(special_token, list):
422
+ - for special_tok in special_token:
423
+ - output = output.replace(special_tok, "")
424
+ - else:
425
+ - output = output.replace(special_token, "")
426
+ - output = output.strip()
427
+ -
428
+ - if conv.name == "xgen" and output.startswith("Assistant:"):
429
+ - output = output.replace("Assistant:", "", 1).strip()
430
+ -
431
+ - example["output"] = output
432
+ - outputs.append(example)
433
+ -
434
+ - outpath = Path("results/alpaca_eval") / f"{model_id}.json"
435
+ - dump_json(outputs, outpath, indent=2)
436
+ -
437
+ -
438
+ -if __name__ == "__main__":
439
+ - parser = argparse.ArgumentParser()
440
+ - parser.add_argument(
441
+ - "--model-path",
442
+ - type=str,
443
+ - required=True,
444
+ - help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
445
+ - )
446
+ - parser.add_argument(
447
+ - "--model-id", type=str, required=True, help="A custom name for the model."
448
+ - )
449
+ - parser.add_argument(
450
+ - "--max-new-token",
451
+ - type=int,
452
+ - default=1024,
453
+ - help="The maximum number of new generated tokens.",
454
+ - )
455
+ -
456
+ - args = parser.parse_args()
457
+ -
458
+ - run_eval(
459
+ - model_path=args.model_path,
460
+ - model_id=args.model_id,
461
+ - max_new_tokens=args.max_new_token,
462
+ - )
463
+ diff --git a/src/eval/show.py b/src/eval/show.py
464
+ index d500054..ea0c210 100644
465
+ --- a/src/eval/show.py
466
+ +++ b/src/eval/show.py
467
+ @@ -55,13 +55,13 @@ def collect_results(result_dir: str, verbose: bool = True) -> dict:
468
+ avg = sum(vals) / len(vals)
469
+ tot_vals.append(avg)
470
+ if verbose:
471
+ - logger.info(f"task: {name}, num: {len(tasks.split(','))}, avg: {avg:.3%}")
472
+ + logger.info(f"task: {name}, num: {len(tasks.split(','))}, avg: {100 * avg:.3f} %")
473
+
474
+ if len(tot_vals) == 0:
475
+ tot_avg = 0.0
476
+ else:
477
+ tot_avg = sum(tot_vals) / len(tot_vals)
478
+ - logger.info(f"total avg: {tot_avg:.3%}")
479
+ + logger.info(f"total avg: {100 * tot_avg:.3f} %")
480
+
481
+
482
+ if __name__ == "__main__":
483
+ diff --git a/src/models/deepseek/modeling_deepseek.py b/src/models/deepseek/modeling_deepseek.py
484
+ index 1dae56e..20498b2 100644
485
+ --- a/src/models/deepseek/modeling_deepseek.py
486
+ +++ b/src/models/deepseek/modeling_deepseek.py
487
+ @@ -20,6 +20,7 @@
488
+ """ PyTorch DeepSeek model."""
489
+ import math
490
+ import warnings
491
+ +from dataclasses import dataclass
492
+ from typing import List, Optional, Tuple, Union
493
+
494
+ import torch
495
+ @@ -297,7 +298,7 @@ class DeepseekMLP(nn.Module):
496
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
497
+ self.act_fn = ACT2FN[config.hidden_act]
498
+
499
+ - def forward(self, x):
500
+ + def forward(self, x, **kwargs):
501
+ if self.config.pretraining_tp > 1:
502
+ slice = self.intermediate_size // self.config.pretraining_tp
503
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
504
+ @@ -328,7 +329,9 @@ class DeepseekMLP(nn.Module):
505
+ else:
506
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
507
+
508
+ - return down_proj
509
+ + bsz, seq_len, _ = x.shape
510
+ + load = torch.zeros(bsz * seq_len, self.config.n_routed_experts)
511
+ + return down_proj, load
512
+
513
+
514
+ class MoEGate(nn.Module):
515
+ @@ -356,7 +359,10 @@ class MoEGate(nn.Module):
516
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
517
+
518
+ def forward(self, hidden_states):
519
+ - bsz, seq_len, h = hidden_states.shape
520
+ + if len(hidden_states.shape) == 2:
521
+ + bsz, h = hidden_states.shape
522
+ + else:
523
+ + bsz, seq_len, h = hidden_states.shape
524
+ ### compute gating score
525
+ hidden_states = hidden_states.view(-1, h)
526
+ logits = F.linear(hidden_states, self.weight, None)
527
+ @@ -404,7 +410,10 @@ class MoEGate(nn.Module):
528
+ aux_loss = (Pi * fi).sum() * self.alpha
529
+ else:
530
+ aux_loss = None
531
+ - return topk_idx, topk_weight, aux_loss
532
+ + _zeros = torch.zeros_like(logits)
533
+ + _scores_filtered = _zeros.scatter(dim=1, index=topk_idx, src=topk_weight)
534
+ + load = (_scores_filtered > 0).sum(0)
535
+ + return topk_idx, topk_weight, aux_loss, load
536
+
537
+
538
+ class AddAuxiliaryLoss(torch.autograd.Function):
539
+ @@ -450,10 +459,19 @@ class DeepseekMoE(nn.Module):
540
+ config=config, intermediate_size=intermediate_size
541
+ )
542
+
543
+ - def forward(self, hidden_states):
544
+ + def forward(self, hidden_states, attention_mask=None):
545
+ + bsz, seq_len, hsz = hidden_states.shape
546
+ + hidden_states = hidden_states.reshape(-1, hsz)
547
+ + flattened_mask = None
548
+ + flattened_shape = None
549
+ + if attention_mask is not None and len(attention_mask.shape) == 2:
550
+ + flattened_mask = attention_mask.flatten()
551
+ + flattened_shape = flattened_mask.shape
552
+ + hidden_states = hidden_states[flattened_mask.bool()]
553
+ +
554
+ identity = hidden_states
555
+ orig_shape = hidden_states.shape
556
+ - topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
557
+ + topk_idx, topk_weight, aux_loss, load = self.gate(hidden_states)
558
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
559
+ flat_topk_idx = topk_idx.view(-1)
560
+ if self.training:
561
+ @@ -472,7 +490,15 @@ class DeepseekMoE(nn.Module):
562
+ ).view(*orig_shape)
563
+ if self.config.n_shared_experts is not None:
564
+ y = y + self.shared_experts(identity)
565
+ - return y
566
+ +
567
+ + if flattened_mask is not None:
568
+ + _y = torch.zeros(flattened_shape + (hsz,), dtype=y.dtype, device=y.device)
569
+ + _y[flattened_mask.bool()] = y
570
+ + y = _y
571
+ +
572
+ + y = y.reshape(bsz, seq_len, hsz)
573
+ +
574
+ + return y, load
575
+
576
+ @torch.no_grad()
577
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
578
+ @@ -1163,7 +1189,7 @@ class DeepseekDecoderLayer(nn.Module):
579
+ # Fully Connected
580
+ residual = hidden_states
581
+ hidden_states = self.post_attention_layernorm(hidden_states)
582
+ - hidden_states = self.mlp(hidden_states)
583
+ + hidden_states, load = self.mlp(hidden_states, attention_mask=attention_mask)
584
+ hidden_states = residual + hidden_states
585
+
586
+ outputs = (hidden_states,)
587
+ @@ -1174,6 +1200,8 @@ class DeepseekDecoderLayer(nn.Module):
588
+ if use_cache:
589
+ outputs += (present_key_value,)
590
+
591
+ + outputs += (load,)
592
+ +
593
+ return outputs
594
+
595
+
596
+ @@ -1220,6 +1248,11 @@ class DeepseekPreTrainedModel(PreTrainedModel):
597
+ module.weight.data[module.padding_idx].zero_()
598
+
599
+
600
+ +@dataclass
601
+ +class BaseMoEModelOutputWithPast(BaseModelOutputWithPast):
602
+ + gate_load: Optional[torch.Tensor] = None
603
+ +
604
+ +
605
+ Deepseek_INPUTS_DOCSTRING = r"""
606
+ Args:
607
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
608
+ @@ -1429,6 +1462,7 @@ class DeepseekModel(DeepseekPreTrainedModel):
609
+ # decoder layers
610
+ all_hidden_states = () if output_hidden_states else None
611
+ all_self_attns = () if output_attentions else None
612
+ + gate_load = ()
613
+ next_decoder_cache = None
614
+
615
+ for decoder_layer in self.layers:
616
+ @@ -1463,6 +1497,8 @@ class DeepseekModel(DeepseekPreTrainedModel):
617
+ if output_attentions:
618
+ all_self_attns += (layer_outputs[1],)
619
+
620
+ + gate_load += (layer_outputs[-1],)
621
+ +
622
+ hidden_states = self.norm(hidden_states)
623
+
624
+ # add hidden states from the last decoder layer
625
+ @@ -1482,14 +1518,20 @@ class DeepseekModel(DeepseekPreTrainedModel):
626
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
627
+ if v is not None
628
+ )
629
+ - return BaseModelOutputWithPast(
630
+ + return BaseMoEModelOutputWithPast(
631
+ last_hidden_state=hidden_states,
632
+ past_key_values=next_cache,
633
+ hidden_states=all_hidden_states,
634
+ attentions=all_self_attns,
635
+ + gate_load=gate_load,
636
+ )
637
+
638
+
639
+ +@dataclass
640
+ +class MoECausalLMOutputWithPast(CausalLMOutputWithPast):
641
+ + gate_load: Optional[torch.Tensor] = None
642
+ +
643
+ +
644
+ class DeepseekForCausalLM(DeepseekPreTrainedModel):
645
+ _tied_weights_keys = ["lm_head.weight"]
646
+
647
+ @@ -1620,12 +1662,13 @@ class DeepseekForCausalLM(DeepseekPreTrainedModel):
648
+ output = (logits,) + outputs[1:]
649
+ return (loss,) + output if loss is not None else output
650
+
651
+ - return CausalLMOutputWithPast(
652
+ + return MoECausalLMOutputWithPast(
653
+ loss=loss,
654
+ logits=logits,
655
+ past_key_values=outputs.past_key_values,
656
+ hidden_states=outputs.hidden_states,
657
+ attentions=outputs.attentions,
658
+ + gate_load=outputs.gate_load,
659
+ )
660
+
661
+ def prepare_inputs_for_generation(
662
+ diff --git a/src/utils/config.py b/src/utils/config.py
663
+ index 3ea5283..d4060d9 100644
664
+ --- a/src/utils/config.py
665
+ +++ b/src/utils/config.py
666
+ @@ -6,6 +6,7 @@ import torch
667
+ import transformers
668
+
669
+ from src.utils.io import load_json
670
+ +from src.data import get_uniform_sampling_ratio
671
+
672
+
673
+ @dataclass
674
+ @@ -33,7 +34,9 @@ class ModelArguments:
675
+ )
676
+ attn_impl: str = field(
677
+ default="flash_attention_2",
678
+ - metadata={"help": "attention implementation, choice from [eager, flash_attention_2, sdpa] (default: `flash_attention_2`)"}
679
+ + metadata={
680
+ + "help": "attention implementation, choice from [eager, flash_attention_2, sdpa] (default: `flash_attention_2`)"
681
+ + },
682
+ )
683
+
684
+ def __post_init__(self):
685
+ @@ -56,6 +59,18 @@ class DataArguments:
686
+ default="data/merged",
687
+ metadata={"help": "Path to dataset directory or a single jsonl file"},
688
+ )
689
+ + prob_map: str = field(
690
+ + default=None,
691
+ + metadata={"help": "Path to the probability map file"},
692
+ + )
693
+ +
694
+ + def __post_init__(self):
695
+ + if self.prob_map is not None:
696
+ + if not pathlib.Path(self.prob_map).exists():
697
+ + raise ValueError(f"Probability map file {self.prob_map} not found")
698
+ + self.prob_map = load_json(self.prob_map)
699
+ + else:
700
+ + self.prob_map = get_uniform_sampling_ratio(self.dataset_dir_or_path)
701
+
702
+
703
+ @dataclass
704
+ @@ -70,9 +85,7 @@ class TrainingArguments(transformers.TrainingArguments):
705
+ )
706
+ max_eval_steps_per_type: int = field(
707
+ default=10,
708
+ - metadata={
709
+ - "help": "Maximum number of steps to perform during evaluation."
710
+ - },
711
+ + metadata={"help": "Maximum number of steps to perform during evaluation."},
712
+ )
713
+ dynamic_sampling_sim_type: Literal["cos", "l2"] = field(
714
+ default="l2",
715
+ @@ -88,7 +101,5 @@ class TrainingArguments(transformers.TrainingArguments):
716
+ )
717
+ freeze_gate: bool = field(
718
+ default=False,
719
+ - metadata={
720
+ - "help": "Whether to freeze the gate during training."
721
+ - },
722
+ + metadata={"help": "Whether to freeze the gate during training."},
723
+ )
724
+ diff --git a/src/utils/visualization.py b/src/utils/visualization.py
725
+ index 794f6c8..02bd236 100644
726
+ --- a/src/utils/visualization.py
727
+ +++ b/src/utils/visualization.py
728
+ @@ -180,6 +180,86 @@ def gate_load_stats(model_dir, data_dir, result_dir, update_strategy: str = "cos
729
+ )
730
+
731
+
732
+ +def sampling_info_stats(filepath: str, data_type: str, output_dir: str):
733
+ + from pathlib import Path
734
+ + import numpy as np
735
+ + from src.utils.io import load_jsonlines
736
+ +
737
+ + Path(output_dir).mkdir(exist_ok=True, parents=True)
738
+ +
739
+ + data = load_jsonlines(filepath)
740
+ + step2data = {ins["step"]: ins for ins in data}
741
+ +
742
+ + data_types = sorted(data[0]["old_prob_map"].keys())
743
+ + data_type_idx = data_types.index(data_type)
744
+ +
745
+ + probs = []
746
+ + loads = []
747
+ + sims = []
748
+ + steps = sorted(step2data.keys())
749
+ + for step in steps:
750
+ + ins = step2data[step]
751
+ + probs.append(ins["old_prob_map"][data_type])
752
+ + loads.append(ins["name2load"][data_type])
753
+ + sims.append(ins["sim"][data_type_idx])
754
+ +
755
+ + # probs
756
+ + fig = plt.figure()
757
+ + ax = fig.add_subplot(111)
758
+ + ax.plot(steps, probs)
759
+ + ax.set_title(f"Sampling Probability of {data_type}")
760
+ + ax.set_xlabel("step")
761
+ + fig.savefig(f"{output_dir}/prob-{data_type}.png")
762
+ +
763
+ + # loads
764
+ + def cv_square(data):
765
+ + return np.var(data, axis=1) / (np.mean(data, axis=1)**2 + 1e-10)
766
+ +
767
+ + fig = plt.figure()
768
+ + ax = fig.add_subplot(111)
769
+ + ax.plot(steps, cv_square(loads))
770
+ + ax.set_title(f"cv(load)^2 of {data_type}")
771
+ + ax.set_xlabel("step")
772
+ + fig.savefig(f"{output_dir}/load_cv-{data_type}.png")
773
+ +
774
+ + # sims
775
+ + fig = plt.figure()
776
+ + ax = fig.add_subplot(111)
777
+ + ax.plot(steps, np.mean(sims, axis=1))
778
+ + ax.set_title(f"Mean Similarities with {data_type}")
779
+ + ax.set_xlabel("step")
780
+ + fig.savefig(f"{output_dir}/sim-{data_type}.png")
781
+ +
782
+ +
783
+ +def test_sampling_convergence():
784
+ + from collections import defaultdict
785
+ + from src.callbacks import AdaptiveSamplingCallback
786
+ +
787
+ + # freeze gate
788
+ + name2load = {"code": [0.1359794776119403, 0.1333115671641791, 0.12858208955223882, 0.10330223880597016, 0.12544776119402984, 0.12625932835820897, 0.12761194029850748, 0.11950559701492537], "orca": [0.1509941502743006, 0.11721425756978752, 0.1232988815809414, 0.12714439426545024, 0.11256554420634679, 0.14008274482465977, 0.11819552632376563, 0.11050450095474797], "math": [0.15956486572028086, 0.10727138452881943, 0.11506675888262392, 0.10958069091633744, 0.11805010139847842, 0.11915200393871546, 0.13648938539627462, 0.13482480921846976], "sharegpt": [0.15337086599959998, 0.11428233411553493, 0.12873151621889287, 0.1177436980734424, 0.11538123789498336, 0.13793986642403783, 0.12419686111124664, 0.10835362016226212]} # fmt: skip
789
+ + # # dynamic
790
+ + # name2load = {"code": [0.14031716417910448, 0.1310634328358209, 0.12651119402985075, 0.10993470149253731, 0.12196828358208955, 0.12552238805970148, 0.12791977611940297, 0.11676305970149255], "orca": [0.15106234655836084, 0.11803640166095838, 0.12349968175067437, 0.12884551268450883, 0.11344072985178673, 0.1383778377231534, 0.11733170672566907, 0.1094057830448883], "math": [0.16001617686708006, 0.10756444371505268, 0.11391210568886491, 0.114803005615014, 0.11676650216277679, 0.1177863481308685, 0.13630182751708533, 0.13284959030325763], "sharegpt": [0.15440024978412215, 0.113654214863131, 0.12914741653941664, 0.12104040941178769, 0.11470799162832905, 0.13593110446537907, 0.12316259873058931, 0.10795601457724527]} # fmt: skip
791
+ + names = sorted(name2load.keys())
792
+ + callback = AdaptiveSamplingCallback()
793
+ + callback.prob_map = {"code": 0.25, "math": 0.25, "orca": 0.25, "sharegpt": 0.25}
794
+ + name2probs = defaultdict(list)
795
+ + for _ in range(100):
796
+ + for name in names:
797
+ + name2probs[name].append(callback.prob_map[name])
798
+ + new_name2prob, _ = callback._update_prob_map(name2load)
799
+ + callback.prob_map = new_name2prob
800
+ + print(f"final prob_map: {callback.prob_map}")
801
+ +
802
+ + fig = plt.figure()
803
+ + ax = fig.add_subplot(111)
804
+ + for name in names:
805
+ + ax.plot(name2probs[name], label=name)
806
+ + ax.legend()
807
+ + ax.set_title("Sampling Probability")
808
+ + ax.set_xlabel("step")
809
+ + fig.savefig("results/sampling_convergence.png")
810
+ +
811
+ +
812
+ if __name__ == "__main__":
813
+ # gate_load_stats(
814
+ # "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
815
+ @@ -195,12 +275,12 @@ if __name__ == "__main__":
816
+ # "results/gate_load_vis_llama_moe_2_8_orca_4clusters",
817
+ # )
818
+
819
+ - gate_load_stats(
820
+ - "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
821
+ - "data/four_types_mix/dev",
822
+ - "results/debug",
823
+ - update_strategy="l2",
824
+ - )
825
+ + # gate_load_stats(
826
+ + # "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
827
+ + # "data/four_types_mix/dev",
828
+ + # "results/debug",
829
+ + # update_strategy="l2",
830
+ + # )
831
+
832
+ # gate_load_stats(
833
+ # "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
834
+ @@ -227,3 +307,29 @@ if __name__ == "__main__":
835
+ # "results/gate_load_vis_llama_moe_2_8_four_types_mix_l2",
836
+ # update_strategy="l2"
837
+ # )
838
+ +
839
+ + # sampling_info_stats(
840
+ + # "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_freeze_gate/moe_sft-2491632/sampling_info/data.jsonl",
841
+ + # "code",
842
+ + # "results/sampling_info/llama_moe_four_mix_wo_pad_freeze_gate/code",
843
+ + # )
844
+ +
845
+ + # sampling_info_stats(
846
+ + # "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad/moe_sft-2491633/sampling_info/data.jsonl",
847
+ + # "code",
848
+ + # "results/sampling_info/llama_moe_four_mix_wo_pad/code",
849
+ + # )
850
+ +
851
+ + # sampling_info_stats(
852
+ + # "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_freeze_gate_wo_gate_noise/moe_sft-2493315/sampling_info/data.jsonl",
853
+ + # "code",
854
+ + # "results/sampling_info/llama_moe_four_mix_wo_pad_freeze_gate_wo_gate_noise/code",
855
+ + # )
856
+ +
857
+ + # sampling_info_stats(
858
+ + # "outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_wo_gate_noise/moe_sft-2492650/sampling_info/data.jsonl",
859
+ + # "code",
860
+ + # "results/sampling_info/llama_moe_four_mix_wo_pad_wo_gate_noise/code",
861
+ + # )
862
+ +
863
+ + test_sampling_convergence()
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.36.2"
7
+ }
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8919f505f53749e1c46511cd975e9da2c91fbcd8105ad30bc26ea2bb5fec3f38
3
+ size 4996976432
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61cb496bd075daa995cf0341587193b2e7a4d5805b4aa561bff4013b1861afff
3
+ size 4982823704
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21fffb1ada83903f9906325e0244222f88a5a97fdc3ab778e424f940e2d07974
3
+ size 3501371152
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_llama_moe_hf.py ADDED
@@ -0,0 +1,1690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.distributions.normal import Normal
11
+ from transformers.modeling_outputs import (
12
+ CausalLMOutputWithPast,
13
+ )
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.activations import ACT2FN
16
+ from transformers.utils import ModelOutput, logging
17
+ from transformers.cache_utils import Cache, DynamicCache
18
+ from transformers.modeling_attn_mask_utils import (
19
+ AttentionMaskConverter,
20
+ _prepare_4d_attention_mask,
21
+ _prepare_4d_causal_attention_mask,
22
+ _prepare_4d_causal_attention_mask_for_sdpa,
23
+ )
24
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10
25
+
26
+ from .configuration_llama_moe import LlamaMoEConfig
27
+
28
+
29
+ if is_flash_attn_2_available():
30
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
31
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
32
+
33
+
34
+ def _get_unpad_data(attention_mask):
35
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
36
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
37
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
38
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
39
+ return (
40
+ indices,
41
+ cu_seqlens,
42
+ max_seqlen_in_batch,
43
+ )
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _CONFIG_FOR_DOC = "LlamaMoEConfig"
49
+
50
+
51
+ @dataclass
52
+ class CalculatorOutput(ModelOutput):
53
+ hidden_states: Optional[torch.FloatTensor] = None
54
+ num_dropped_tokens: Optional[int] = None
55
+
56
+
57
+ @dataclass
58
+ class BaseMoEModelOutputWithPast(ModelOutput):
59
+ """
60
+ Args:
61
+ num_dropped_tokens: layer idx to the number of dropped tokens
62
+ """
63
+
64
+ last_hidden_state: torch.FloatTensor = None
65
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
66
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
67
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
68
+ balance_loss: Optional[float] = None
69
+ num_dropped_tokens: Optional[Tuple[torch.Tensor]] = None
70
+ gate_load: Optional[Tuple[list]] = None
71
+ gate_importance: Optional[Tuple[list]] = None
72
+
73
+
74
+ @dataclass
75
+ class MoECausalLMOutputWithPast(CausalLMOutputWithPast):
76
+ balance_loss: Optional[float] = None
77
+ num_dropped_tokens: Optional[Tuple[int]] = None
78
+ gate_load: Optional[Tuple[list[torch.Tensor]]] = None
79
+ gate_importance: Optional[Tuple[list[torch.Tensor]]] = None
80
+
81
+
82
+ @dataclass
83
+ class MoEMlpOutput(ModelOutput):
84
+ hidden_states: Optional[torch.FloatTensor] = None
85
+ balance_loss: Optional[torch.FloatTensor] = None
86
+ num_dropped_tokens: Optional[int] = None
87
+ gate_load: Optional[list] = None
88
+ gate_importance: Optional[list] = None
89
+
90
+
91
+ def _make_causal_mask(
92
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
93
+ ):
94
+ """
95
+ Make causal mask used for bi-directional self-attention.
96
+ """
97
+ bsz, tgt_len = input_ids_shape
98
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
99
+ mask_cond = torch.arange(mask.size(-1), device=device)
100
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
101
+ mask = mask.to(dtype)
102
+
103
+ if past_key_values_length > 0:
104
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
105
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
106
+
107
+
108
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
109
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
110
+ """
111
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
112
+ """
113
+ bsz, src_len = mask.size()
114
+ tgt_len = tgt_len if tgt_len is not None else src_len
115
+
116
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
117
+
118
+ inverted_mask = 1.0 - expanded_mask
119
+
120
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
121
+
122
+
123
+ class LlamaRMSNorm(nn.Module):
124
+ def __init__(self, hidden_size, eps=1e-6):
125
+ """
126
+ LlamaRMSNorm is equivalent to T5LayerNorm
127
+ """
128
+ super().__init__()
129
+ self.weight = nn.Parameter(torch.ones(hidden_size))
130
+ self.variance_epsilon = eps
131
+
132
+ def forward(self, hidden_states):
133
+ input_dtype = hidden_states.dtype
134
+ hidden_states = hidden_states.to(torch.float32)
135
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
136
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
137
+ return self.weight * hidden_states.to(input_dtype)
138
+
139
+
140
+ class LlamaRotaryEmbedding(torch.nn.Module):
141
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
142
+ super().__init__()
143
+
144
+ self.dim = dim
145
+ self.max_position_embeddings = max_position_embeddings
146
+ self.base = base
147
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
148
+ self.register_buffer("inv_freq", inv_freq)
149
+
150
+ # Build here to make `torch.jit.trace` work.
151
+ self._set_cos_sin_cache(
152
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
153
+ )
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
158
+
159
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
160
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
163
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
164
+
165
+ def forward(self, x, seq_len=None):
166
+ # x: [bs, num_attention_heads, seq_len, head_size]
167
+ if seq_len > self.max_seq_len_cached:
168
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
169
+
170
+ return (
171
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
172
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
173
+ )
174
+
175
+
176
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
177
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
178
+
179
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
180
+ self.scaling_factor = scaling_factor
181
+ super().__init__(dim, max_position_embeddings, base, device)
182
+
183
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
184
+ self.max_seq_len_cached = seq_len
185
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
186
+ t = t / self.scaling_factor
187
+
188
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
189
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
190
+ emb = torch.cat((freqs, freqs), dim=-1)
191
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
192
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
193
+
194
+
195
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
196
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
197
+
198
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
199
+ self.scaling_factor = scaling_factor
200
+ super().__init__(dim, max_position_embeddings, base, device)
201
+
202
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
203
+ self.max_seq_len_cached = seq_len
204
+
205
+ if seq_len > self.max_position_embeddings:
206
+ base = self.base * (
207
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
208
+ ) ** (self.dim / (self.dim - 2))
209
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
210
+ self.register_buffer("inv_freq", inv_freq)
211
+
212
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
213
+
214
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
215
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
216
+ emb = torch.cat((freqs, freqs), dim=-1)
217
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
218
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
219
+
220
+
221
+ def rotate_half(x):
222
+ """Rotates half the hidden dims of the input."""
223
+ x1 = x[..., : x.shape[-1] // 2]
224
+ x2 = x[..., x.shape[-1] // 2 :]
225
+ return torch.cat((-x2, x1), dim=-1)
226
+
227
+
228
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
229
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
230
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
231
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
232
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
233
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
234
+ q_embed = (q * cos) + (rotate_half(q) * sin)
235
+ k_embed = (k * cos) + (rotate_half(k) * sin)
236
+ return q_embed, k_embed
237
+
238
+
239
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
240
+ """
241
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
242
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
243
+ """
244
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
245
+ if n_rep == 1:
246
+ return hidden_states
247
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
248
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
249
+
250
+
251
+ class LlamaAttention(nn.Module):
252
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
253
+
254
+ def __init__(self, config: LlamaMoEConfig, layer_idx: Optional[int] = None):
255
+ super().__init__()
256
+ self.config = config
257
+ self.layer_idx = layer_idx
258
+ if layer_idx is None:
259
+ logger.warning_once(
260
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
261
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
262
+ "when creating this class."
263
+ )
264
+
265
+ self.attention_dropout = config.attention_dropout
266
+ self.hidden_size = config.hidden_size
267
+ self.num_heads = config.num_attention_heads
268
+ self.head_dim = self.hidden_size // self.num_heads
269
+ self.num_key_value_heads = config.num_key_value_heads
270
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
271
+ self.max_position_embeddings = config.max_position_embeddings
272
+ self.rope_theta = config.rope_theta
273
+ self.is_causal = True
274
+
275
+ if (self.head_dim * self.num_heads) != self.hidden_size:
276
+ raise ValueError(
277
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
278
+ f" and `num_heads`: {self.num_heads})."
279
+ )
280
+
281
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
282
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
283
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
284
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
285
+ self._init_rope()
286
+
287
+ def _init_rope(self):
288
+ if self.config.rope_scaling is None:
289
+ self.rotary_emb = LlamaRotaryEmbedding(
290
+ self.head_dim,
291
+ max_position_embeddings=self.max_position_embeddings,
292
+ base=self.rope_theta,
293
+ )
294
+ else:
295
+ scaling_type = self.config.rope_scaling["type"]
296
+ scaling_factor = self.config.rope_scaling["factor"]
297
+ if scaling_type == "linear":
298
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
299
+ self.head_dim,
300
+ max_position_embeddings=self.max_position_embeddings,
301
+ scaling_factor=scaling_factor,
302
+ base=self.rope_theta,
303
+ )
304
+ elif scaling_type == "dynamic":
305
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
306
+ self.head_dim,
307
+ max_position_embeddings=self.max_position_embeddings,
308
+ scaling_factor=scaling_factor,
309
+ base=self.rope_theta,
310
+ )
311
+ else:
312
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
313
+
314
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
315
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
316
+
317
+ def forward(
318
+ self,
319
+ hidden_states: torch.Tensor,
320
+ attention_mask: Optional[torch.Tensor] = None,
321
+ position_ids: Optional[torch.LongTensor] = None,
322
+ past_key_value: Optional[Cache] = None,
323
+ output_attentions: bool = False,
324
+ use_cache: bool = False,
325
+ **kwargs,
326
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
327
+ if "padding_mask" in kwargs:
328
+ warnings.warn(
329
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
330
+ )
331
+
332
+ bsz, q_len, _ = hidden_states.size()
333
+
334
+ if self.config.pretraining_tp > 1:
335
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
336
+ query_slices = self.q_proj.weight.split(
337
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
338
+ )
339
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
340
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
341
+
342
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
343
+ query_states = torch.cat(query_states, dim=-1)
344
+
345
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
346
+ key_states = torch.cat(key_states, dim=-1)
347
+
348
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
349
+ value_states = torch.cat(value_states, dim=-1)
350
+
351
+ else:
352
+ query_states = self.q_proj(hidden_states)
353
+ key_states = self.k_proj(hidden_states)
354
+ value_states = self.v_proj(hidden_states)
355
+
356
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
357
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
358
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
+
360
+ kv_seq_len = key_states.shape[-2]
361
+ if past_key_value is not None:
362
+ if self.layer_idx is None:
363
+ raise ValueError(
364
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
365
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
366
+ "with a layer index."
367
+ )
368
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
369
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
370
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
371
+
372
+ if past_key_value is not None:
373
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
374
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
375
+
376
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
377
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
378
+
379
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
380
+
381
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
382
+ raise ValueError(
383
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
384
+ f" {attn_weights.size()}"
385
+ )
386
+
387
+ if attention_mask is not None:
388
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
389
+ raise ValueError(
390
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
391
+ )
392
+ attn_weights = attn_weights + attention_mask
393
+
394
+ # upcast attention to fp32
395
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
396
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
397
+ attn_output = torch.matmul(attn_weights, value_states)
398
+
399
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
400
+ raise ValueError(
401
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
402
+ f" {attn_output.size()}"
403
+ )
404
+
405
+ attn_output = attn_output.transpose(1, 2).contiguous()
406
+
407
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
408
+
409
+ if self.config.pretraining_tp > 1:
410
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
411
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
412
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
413
+ else:
414
+ attn_output = self.o_proj(attn_output)
415
+
416
+ if not output_attentions:
417
+ attn_weights = None
418
+
419
+ return attn_output, attn_weights, past_key_value
420
+
421
+
422
+ class LlamaFlashAttention2(LlamaAttention):
423
+ """
424
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
425
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
426
+ flash attention and deal with padding tokens in case the input contains any of them.
427
+ """
428
+
429
+ def __init__(self, *args, **kwargs):
430
+ super().__init__(*args, **kwargs)
431
+
432
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
433
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
434
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
435
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
436
+
437
+ def forward(
438
+ self,
439
+ hidden_states: torch.Tensor,
440
+ attention_mask: Optional[torch.LongTensor] = None,
441
+ position_ids: Optional[torch.LongTensor] = None,
442
+ past_key_value: Optional[Cache] = None,
443
+ output_attentions: bool = False,
444
+ use_cache: bool = False,
445
+ **kwargs,
446
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
447
+ # LlamaFlashAttention2 attention does not support output_attentions
448
+ if "padding_mask" in kwargs:
449
+ warnings.warn(
450
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
451
+ )
452
+
453
+ # overwrite attention_mask with padding_mask
454
+ attention_mask = kwargs.pop("padding_mask")
455
+
456
+ output_attentions = False
457
+
458
+ bsz, q_len, _ = hidden_states.size()
459
+
460
+ query_states = self.q_proj(hidden_states)
461
+ key_states = self.k_proj(hidden_states)
462
+ value_states = self.v_proj(hidden_states)
463
+
464
+ # Flash attention requires the input to have the shape
465
+ # batch_size x seq_length x head_dim x hidden_dim
466
+ # therefore we just need to keep the original shape
467
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
468
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
469
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
470
+
471
+ kv_seq_len = key_states.shape[-2]
472
+ if past_key_value is not None:
473
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
474
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
475
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
476
+
477
+ if past_key_value is not None:
478
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
479
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
480
+
481
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
482
+ # to be able to avoid many of these transpose/reshape/view.
483
+ query_states = query_states.transpose(1, 2)
484
+ key_states = key_states.transpose(1, 2)
485
+ value_states = value_states.transpose(1, 2)
486
+
487
+ dropout_rate = self.attention_dropout if self.training else 0.0
488
+
489
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
490
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
491
+ # cast them back in the correct dtype just to be sure everything works as expected.
492
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
493
+ # in fp32. (LlamaRMSNorm handles it correctly)
494
+
495
+ input_dtype = query_states.dtype
496
+ if input_dtype == torch.float32:
497
+ if torch.is_autocast_enabled():
498
+ target_dtype = torch.get_autocast_gpu_dtype()
499
+ # Handle the case where the model is quantized
500
+ elif hasattr(self.config, "_pre_quantization_dtype"):
501
+ target_dtype = self.config._pre_quantization_dtype
502
+ else:
503
+ target_dtype = self.q_proj.weight.dtype
504
+
505
+ logger.warning_once(
506
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
507
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
508
+ f" {target_dtype}."
509
+ )
510
+
511
+ query_states = query_states.to(target_dtype)
512
+ key_states = key_states.to(target_dtype)
513
+ value_states = value_states.to(target_dtype)
514
+
515
+ attn_output = self._flash_attention_forward(
516
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
517
+ )
518
+
519
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
520
+ attn_output = self.o_proj(attn_output)
521
+
522
+ if not output_attentions:
523
+ attn_weights = None
524
+
525
+ return attn_output, attn_weights, past_key_value
526
+
527
+ def _flash_attention_forward(
528
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
529
+ ):
530
+ """
531
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
532
+ first unpad the input, then computes the attention scores and pad the final attention scores.
533
+
534
+ Args:
535
+ query_states (`torch.Tensor`):
536
+ Input query states to be passed to Flash Attention API
537
+ key_states (`torch.Tensor`):
538
+ Input key states to be passed to Flash Attention API
539
+ value_states (`torch.Tensor`):
540
+ Input value states to be passed to Flash Attention API
541
+ attention_mask (`torch.Tensor`):
542
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
543
+ position of padding tokens and 1 for the position of non-padding tokens.
544
+ dropout (`int`, *optional*):
545
+ Attention dropout
546
+ softmax_scale (`float`, *optional*):
547
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
548
+ """
549
+ if not self._flash_attn_uses_top_left_mask:
550
+ causal = self.is_causal
551
+ else:
552
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
553
+ causal = self.is_causal and query_length != 1
554
+
555
+ # Contains at least one padding token in the sequence
556
+ if attention_mask is not None:
557
+ batch_size = query_states.shape[0]
558
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
559
+ query_states, key_states, value_states, attention_mask, query_length
560
+ )
561
+
562
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
563
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
564
+
565
+ attn_output_unpad = flash_attn_varlen_func(
566
+ query_states,
567
+ key_states,
568
+ value_states,
569
+ cu_seqlens_q=cu_seqlens_q,
570
+ cu_seqlens_k=cu_seqlens_k,
571
+ max_seqlen_q=max_seqlen_in_batch_q,
572
+ max_seqlen_k=max_seqlen_in_batch_k,
573
+ dropout_p=dropout,
574
+ softmax_scale=softmax_scale,
575
+ causal=causal,
576
+ )
577
+
578
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
579
+ else:
580
+ attn_output = flash_attn_func(
581
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
582
+ )
583
+
584
+ return attn_output
585
+
586
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
587
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
588
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
589
+
590
+ key_layer = index_first_axis(
591
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
592
+ )
593
+ value_layer = index_first_axis(
594
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
595
+ )
596
+ if query_length == kv_seq_len:
597
+ query_layer = index_first_axis(
598
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
599
+ )
600
+ cu_seqlens_q = cu_seqlens_k
601
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
602
+ indices_q = indices_k
603
+ elif query_length == 1:
604
+ max_seqlen_in_batch_q = 1
605
+ cu_seqlens_q = torch.arange(
606
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
607
+ ) # There is a memcpy here, that is very bad.
608
+ indices_q = cu_seqlens_q[:-1]
609
+ query_layer = query_layer.squeeze(1)
610
+ else:
611
+ # The -q_len: slice assumes left padding.
612
+ attention_mask = attention_mask[:, -query_length:]
613
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
614
+
615
+ return (
616
+ query_layer,
617
+ key_layer,
618
+ value_layer,
619
+ indices_q,
620
+ (cu_seqlens_q, cu_seqlens_k),
621
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
622
+ )
623
+
624
+
625
+ class LlamaSdpaAttention(LlamaAttention):
626
+ """
627
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
628
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
629
+ SDPA API.
630
+ """
631
+
632
+ # Adapted from LlamaAttention.forward
633
+ def forward(
634
+ self,
635
+ hidden_states: torch.Tensor,
636
+ attention_mask: Optional[torch.Tensor] = None,
637
+ position_ids: Optional[torch.LongTensor] = None,
638
+ past_key_value: Optional[Cache] = None,
639
+ output_attentions: bool = False,
640
+ use_cache: bool = False,
641
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
642
+ if output_attentions:
643
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
644
+ logger.warning_once(
645
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
646
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
647
+ )
648
+ return super().forward(
649
+ hidden_states=hidden_states,
650
+ attention_mask=attention_mask,
651
+ position_ids=position_ids,
652
+ past_key_value=past_key_value,
653
+ output_attentions=output_attentions,
654
+ use_cache=use_cache,
655
+ )
656
+
657
+ bsz, q_len, _ = hidden_states.size()
658
+
659
+ query_states = self.q_proj(hidden_states)
660
+ key_states = self.k_proj(hidden_states)
661
+ value_states = self.v_proj(hidden_states)
662
+
663
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
664
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
665
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
666
+
667
+ kv_seq_len = key_states.shape[-2]
668
+ if past_key_value is not None:
669
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
670
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
671
+
672
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
673
+
674
+ if past_key_value is not None:
675
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
676
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
677
+
678
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
679
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
680
+
681
+ if attention_mask is not None:
682
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
683
+ raise ValueError(
684
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
685
+ )
686
+
687
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
688
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
689
+ if query_states.device.type == "cuda" and attention_mask is not None:
690
+ query_states = query_states.contiguous()
691
+ key_states = key_states.contiguous()
692
+ value_states = value_states.contiguous()
693
+
694
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
695
+ query_states,
696
+ key_states,
697
+ value_states,
698
+ attn_mask=attention_mask,
699
+ dropout_p=self.attention_dropout if self.training else 0.0,
700
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
701
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
702
+ )
703
+
704
+ attn_output = attn_output.transpose(1, 2).contiguous()
705
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
706
+
707
+ attn_output = self.o_proj(attn_output)
708
+
709
+ return attn_output, None, past_key_value
710
+
711
+
712
+ LLAMA_ATTENTION_CLASSES = {
713
+ "eager": LlamaAttention,
714
+ "flash_attention_2": LlamaFlashAttention2,
715
+ "sdpa": LlamaSdpaAttention,
716
+ }
717
+
718
+
719
+ class TopKBalancedNoisyGate(nn.Module):
720
+ def __init__(
721
+ self,
722
+ input_size,
723
+ num_experts,
724
+ num_selects,
725
+ gate_network="mlp",
726
+ use_softmax=True,
727
+ use_balance=True,
728
+ balance_loss_weight=1e-2,
729
+ add_noise=True,
730
+ noise_epsilon=1e-2,
731
+ ):
732
+ super(TopKBalancedNoisyGate, self).__init__()
733
+ assert num_selects <= num_experts
734
+ self.input_size = input_size
735
+ self.num_experts = num_experts
736
+ self.num_selects = num_selects
737
+
738
+ self.gate_network_type = gate_network
739
+ self.gate_network = self.get_gate_network(gate_network, input_size, num_experts)
740
+
741
+ self.use_softmax = use_softmax
742
+ self.softmax = nn.Softmax(1)
743
+
744
+ self.use_balance = use_balance
745
+ self.balance_loss_weight = balance_loss_weight
746
+
747
+ # add_noise
748
+ self.add_noise = add_noise
749
+ self.noise_epsilon = noise_epsilon
750
+ self.warned = False
751
+ if self.add_noise:
752
+ self.weight_noise = nn.Linear(input_size, num_experts, bias=False)
753
+ self.weight_noise.weight.data = torch.zeros(
754
+ (num_experts, input_size),
755
+ requires_grad=True,
756
+ device=self.weight_noise.weight.data.device,
757
+ dtype=self.weight_noise.weight.data.dtype,
758
+ )
759
+ self.mean = 0.0
760
+ self.std = 1.0
761
+ self.normal = Normal(self.mean, self.std)
762
+ self.softplus = nn.Softplus()
763
+
764
+ self.reset_parameters()
765
+
766
+ def get_gate_network(self, gate_type, input_size, num_experts):
767
+ gate_type = gate_type.lower()
768
+
769
+ if gate_type == "linear":
770
+ gate_network = nn.Linear(input_size, num_experts, bias=False)
771
+ nn.init.zeros_(gate_network.weight)
772
+ elif gate_type == "mlp":
773
+ gate_network = torch.nn.Sequential(
774
+ torch.nn.Linear(input_size, num_experts, bias=False),
775
+ torch.nn.Tanh(),
776
+ torch.nn.Linear(num_experts, num_experts, bias=False),
777
+ )
778
+ else:
779
+ raise ValueError(f'Unexpected gate_type: {gate_type}.')
780
+
781
+ return gate_network
782
+
783
+ def reset_gate_network(self):
784
+ if "gate_network_type" not in vars(self):
785
+ raise KeyError(f"{type(self)} does not have a gate network.")
786
+ else:
787
+ self.gate_network = self.get_gate_network(
788
+ self.gate_network_type, self.input_size, self.num_experts
789
+ )
790
+
791
+ def reset_parameters(self):
792
+ if self.add_noise:
793
+ nn.init.zeros_(self.weight_noise.weight)
794
+ # nn.init.zeros_(self.weight_noise)
795
+
796
+ def cv_squared(self, x, eps=1e-10):
797
+ """The squared coefficient of variation of a sample.
798
+ Useful as a loss to encourage a positive distribution to be more uniform.
799
+ Epsilons added for numerical stability.
800
+ Returns 0 for an empty Tensor.
801
+ Args:
802
+ x: a `Tensor`.
803
+ Returns:
804
+ a `Scalar`.s
805
+ """
806
+ if x.shape[0] == 1:
807
+ return torch.tensor(0.0, device=x.device)
808
+ return x.float().var() / (x.float().mean() ** 2 + eps)
809
+
810
+ def forward(self, x):
811
+ logits_gate = self.gate_network(x)
812
+ if self.training and self.add_noise:
813
+ noise_mm = self.weight_noise(x)
814
+ noise_control = self.softplus(noise_mm) + self.noise_epsilon
815
+ logits_noise = torch.randn_like(logits_gate) * noise_control
816
+ logits = logits_gate + logits_noise
817
+ else:
818
+ logits = logits_gate
819
+
820
+ top_logits, top_indices = logits.topk(min(self.num_selects + 1, self.num_experts), dim=1) # 选择并排序前k+1个权重
821
+ top_k_logits = top_logits[:, :self.num_selects]
822
+ top_k_indices = top_indices[:, :self.num_selects]
823
+ top_k_scores = self.softmax(top_k_logits.to(torch.float32)) if self.use_softmax else top_k_logits
824
+ top_k_scores = top_k_scores.to(logits.dtype)
825
+
826
+ zeros = torch.zeros_like(logits, requires_grad=True, device=logits.device)
827
+ scores_filtered = zeros.scatter(dim=1, index=top_k_indices, src=top_k_scores) # shape(batch_size, num_experts)
828
+ importance = scores_filtered.sum(0) # shape(num_experts)
829
+
830
+ if self.training:
831
+ if self.add_noise and self.num_selects != self.num_experts:
832
+ batch_size = top_logits.size(0)
833
+ m = top_logits.size(1)
834
+ top_values_flat = top_logits.flatten()
835
+ threshold_positions_if_in = torch.arange(batch_size, device=x.device) * m + self.num_selects
836
+ threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
837
+ is_in = torch.gt(logits_noise, threshold_if_in)
838
+ threshold_positions_if_out = threshold_positions_if_in - 1
839
+ threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
840
+ # is each value currently in the top k.
841
+ prob_if_in = self.normal.cdf((logits_gate - threshold_if_in) / noise_control)
842
+ prob_if_out = self.normal.cdf((logits_gate - threshold_if_out) / noise_control)
843
+ prob = torch.where(is_in, prob_if_in, prob_if_out)
844
+ load = prob.sum(0)
845
+ else:
846
+ load = (scores_filtered > 0).sum(0)
847
+ if not self.add_noise and not self.warned:
848
+ warnings.warn('Gradient-trackable implementation for load calculation is only available when "add_noise=True". '
849
+ 'Training without noise will block the gradient from "load" path and lead to inconsistency in optimization objectives.')
850
+ self.warned = True
851
+ else:
852
+ load = (scores_filtered > 0).sum(0)
853
+
854
+ if self.use_balance:
855
+ balance_loss = self.cv_squared(importance) + self.cv_squared(load)
856
+ balance_loss *= self.balance_loss_weight
857
+ else:
858
+ balance_loss = torch.tensor(-100.0, device=x.device)
859
+
860
+ return {
861
+ "topK_indices": top_k_indices,
862
+ "topK_scores": top_k_scores,
863
+ "balance_loss": balance_loss,
864
+ "load": load,
865
+ "importance": importance,
866
+ }
867
+
868
+
869
+ class LinearGLUExperts(nn.Module):
870
+ """
871
+ Modified from transformers.models.llama.modeling_llama.LlamaMLP
872
+ """
873
+
874
+ __constants__ = [
875
+ "bias",
876
+ "in_features",
877
+ "hidden_features",
878
+ "out_features",
879
+ "hidden_act",
880
+ "num_experts",
881
+ "size_experts",
882
+ ]
883
+
884
+ def __init__(
885
+ self,
886
+ in_features,
887
+ hidden_features,
888
+ out_features,
889
+ hidden_act,
890
+ num_experts,
891
+ size_experts=None,
892
+ bias=True,
893
+ device=None,
894
+ dtype=None,
895
+ ):
896
+ factory_kwargs = {"device": device, "dtype": dtype}
897
+ super(LinearGLUExperts, self).__init__()
898
+ self.in_features = in_features
899
+ self.hidden_features = hidden_features
900
+ self.out_features = out_features
901
+ self.hidden_act = hidden_act
902
+ self.num_experts = num_experts
903
+
904
+ if size_experts is None:
905
+ # all experts share the same number of hidden neurons
906
+ assert hidden_features % num_experts == 0
907
+ size_per_expert = hidden_features // num_experts
908
+ size_experts = [size_per_expert for _ in range(num_experts)]
909
+ else:
910
+ # use specified expert sizes
911
+ assert (
912
+ len(size_experts) == num_experts
913
+ and sum(size_experts) == hidden_features
914
+ )
915
+ self.size_experts = size_experts
916
+
917
+ self.act_fn = ACT2FN[hidden_act]
918
+
919
+ self.weight_gate = nn.ParameterList()
920
+ self.weight_up = nn.ParameterList()
921
+ self.weight_down = nn.ParameterList()
922
+
923
+ for i in range(num_experts):
924
+ # this matrix will be transposed when performing linear forwarding
925
+ this_expert_weight_gate = nn.Parameter(
926
+ torch.empty((size_experts[i], in_features), **factory_kwargs)
927
+ )
928
+ # this matrix will be transposed when performing linear forwarding
929
+ this_expert_weight_up = nn.Parameter(
930
+ torch.empty((size_experts[i], in_features), **factory_kwargs)
931
+ )
932
+ # this matrix will be transposed when performing linear forwarding
933
+ this_expert_weight_down = nn.Parameter(
934
+ torch.empty((out_features, size_experts[i]), **factory_kwargs)
935
+ )
936
+ self.weight_gate.append(this_expert_weight_gate)
937
+ self.weight_up.append(this_expert_weight_up)
938
+ self.weight_down.append(this_expert_weight_down)
939
+
940
+ if bias:
941
+ self.bias_gate = nn.ParameterList()
942
+ self.bias_up = nn.ParameterList()
943
+ self.bias_down = nn.ParameterList()
944
+
945
+ for i in range(num_experts):
946
+ this_expert_bias_gate = nn.Parameter(
947
+ torch.empty((size_experts[i],), **factory_kwargs)
948
+ )
949
+ this_expert_bias_up = nn.Parameter(
950
+ torch.empty((size_experts[i],), **factory_kwargs)
951
+ )
952
+ this_expert_bias_down = nn.Parameter(
953
+ torch.empty((out_features,), **factory_kwargs)
954
+ )
955
+ self.bias_gate.append(this_expert_bias_gate)
956
+ self.bias_up.append(this_expert_bias_up)
957
+ self.bias_down.append(this_expert_bias_down)
958
+ else:
959
+ self.register_parameter("bias_gate", None)
960
+ self.register_parameter("bias_up", None)
961
+ self.register_parameter("bias_down", None)
962
+
963
+ self.reset_parameters()
964
+
965
+ def reset_parameters(self):
966
+ for i in range(self.num_experts):
967
+ nn.init.kaiming_uniform_(self.weight_gate[i], a=math.sqrt(5))
968
+ nn.init.kaiming_uniform_(self.weight_up[i], a=math.sqrt(5))
969
+ nn.init.kaiming_uniform_(self.weight_down[i], a=math.sqrt(5))
970
+ if self.bias_gate is not None:
971
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_gate[i])
972
+ bound = 1 / math.sqrt(fan_in)
973
+ nn.init.uniform_(self.bias_gate[i], -bound, bound)
974
+ if self.bias_up is not None:
975
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_up[i])
976
+ bound = 1 / math.sqrt(fan_in)
977
+ nn.init.uniform_(self.bias_up[i], -bound, bound)
978
+ if self.bias_down is not None:
979
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_down[i])
980
+ bound = 1 / math.sqrt(fan_in)
981
+ nn.init.uniform_(self.bias_down[i], -bound, bound)
982
+
983
+ def forward(self, input, i):
984
+ gate = self.act_fn(
985
+ F.linear(
986
+ input,
987
+ self.weight_gate[i],
988
+ self.bias_gate[i] if self.bias_gate is not None else None,
989
+ )
990
+ )
991
+ up = F.linear(
992
+ input,
993
+ self.weight_up[i],
994
+ self.bias_up[i] if self.bias_up is not None else None,
995
+ )
996
+ down = F.linear(
997
+ gate * up,
998
+ self.weight_down[i],
999
+ self.bias_down[i] if self.bias_down is not None else None,
1000
+ )
1001
+ return down
1002
+
1003
+ def extra_repr(self):
1004
+ return (
1005
+ "in_features={}, hidden_features={}, out_features={}, hidden_act={},"
1006
+ " num_experts={}, size_experts={}, bias={}".format(
1007
+ self.in_features,
1008
+ self.hidden_features,
1009
+ self.out_features,
1010
+ self.hidden_act,
1011
+ self.num_experts,
1012
+ self.size_experts,
1013
+ self.bias_gate is not None,
1014
+ )
1015
+ )
1016
+
1017
+
1018
+ class UniversalCalculator(nn.Module):
1019
+ def __init__(
1020
+ self,
1021
+ experts: LinearGLUExperts,
1022
+ multiply_gate_scores=True,
1023
+ score_scale_factor=1.0,
1024
+ add_weight_norm: bool = False,
1025
+ ):
1026
+ super(UniversalCalculator, self).__init__()
1027
+ self.experts = experts
1028
+ # TODO (zhutong): use vmap to boost the training efficiency
1029
+ # self.experts_vmap = torch.vmap(self.experts)
1030
+ self.multiply_gate_scores = multiply_gate_scores
1031
+ self.score_scale_factor = score_scale_factor
1032
+ self.num_experts = experts.num_experts
1033
+ self.mlp_norm = None
1034
+ if multiply_gate_scores and add_weight_norm:
1035
+ raise NotImplementedError
1036
+
1037
+ def reset_experts(self):
1038
+ self.experts.reset_parameters()
1039
+
1040
+ def forward(
1041
+ self, x, topK_indices, topK_scores, expert_batch_size=None, **kwargs
1042
+ ) -> CalculatorOutput:
1043
+ batch_size = topK_indices.size(0) # topK_indices: (bsz*seq_len, num_selects)
1044
+ num_selects = topK_indices.size(1)
1045
+ topK_indices = topK_indices.flatten() # shape(batch_size*num_selects)
1046
+ topK_scores = topK_scores.flatten() # shape(batch_size*num_selects)
1047
+ batch_indices = torch.arange(
1048
+ batch_size, device=topK_scores.device
1049
+ ).repeat_interleave(num_selects)
1050
+
1051
+ _, index_sorted_topK_indices = topK_indices.sort(0)
1052
+
1053
+ sorted_topK_scores = topK_scores.index_select(0, index_sorted_topK_indices)
1054
+ sorted_batch_indices = batch_indices.index_select(0, index_sorted_topK_indices)
1055
+
1056
+ if expert_batch_size is None:
1057
+ expert_batch_size = topK_indices.bincount(
1058
+ minlength=self.num_experts
1059
+ ).tolist()
1060
+
1061
+ sorted_x = x.index_select(0, sorted_batch_indices)
1062
+ split_x = torch.split(sorted_x, expert_batch_size, dim=0)
1063
+
1064
+ expert_outputs = [
1065
+ self.experts(split_x[i], i)
1066
+ for i in range(self.num_experts)
1067
+ if split_x[i].shape[0] > 0
1068
+ ]
1069
+
1070
+ # (bsz*seq_len*num_selects, hidden_size)
1071
+ cat_expert_outputs = torch.cat(expert_outputs, 0)
1072
+ output_dim = cat_expert_outputs.size(1)
1073
+ if self.multiply_gate_scores:
1074
+ if self.mlp_norm is None:
1075
+ cat_expert_outputs = torch.mul(
1076
+ cat_expert_outputs,
1077
+ sorted_topK_scores.reshape(-1, 1) * self.score_scale_factor,
1078
+ )
1079
+ # cat_expert_outputs = torch.mul(cat_expert_outputs, sorted_topK_scores.reshape(-1, 1) * 1.0)
1080
+ else:
1081
+ cat_expert_outputs = torch.mul(
1082
+ cat_expert_outputs, sorted_topK_scores.reshape(-1, 1)
1083
+ )
1084
+ cat_expert_outputs = self.mlp_norm(cat_expert_outputs)
1085
+
1086
+ zeros = torch.zeros(
1087
+ (batch_size, output_dim),
1088
+ device=cat_expert_outputs.device,
1089
+ dtype=cat_expert_outputs.dtype,
1090
+ )
1091
+ y = zeros.index_add(0, sorted_batch_indices, cat_expert_outputs)
1092
+
1093
+ return CalculatorOutput(hidden_states=y, num_dropped_tokens=torch.tensor(-1.0))
1094
+
1095
+
1096
+ class BaseMoELayer(nn.Module):
1097
+ def __init__(self):
1098
+ super(BaseMoELayer, self).__init__()
1099
+
1100
+ self.gate: TopKBalancedNoisyGate
1101
+ self.calculator: UniversalCalculator
1102
+
1103
+ def _create_gate(self, **kwargs):
1104
+ self.gate_type = kwargs.get("gate_type", "TopKBalancedNoisyGate")
1105
+
1106
+ if self.gate_type == "TopKBalancedNoisyGate": # noisy gate
1107
+ self.gate = TopKBalancedNoisyGate(
1108
+ self.input_size,
1109
+ self.num_experts,
1110
+ self.num_selects,
1111
+ gate_network=kwargs.get("gate_network", "mlp"),
1112
+ use_softmax=kwargs.get("gate_use_softmax", True),
1113
+ use_balance=kwargs.get("gate_use_balance", True),
1114
+ balance_loss_weight=kwargs.get("gate_balance_loss_weight", 1e-2),
1115
+ add_noise=kwargs.get("gate_add_noise", True),
1116
+ noise_epsilon=kwargs.get("gate_noise_epsilon", 1e-2),
1117
+ )
1118
+ else:
1119
+ raise NotImplementedError
1120
+
1121
+ def _create_calculator(self, experts, **kwargs):
1122
+ self.calculator_type = kwargs.get("calculator_type", "UniversalCalculator")
1123
+
1124
+ if self.calculator_type == "UniversalCalculator": # top K calculator
1125
+ self.calculator = UniversalCalculator(
1126
+ experts,
1127
+ multiply_gate_scores=kwargs.get("multiply_gate_scores", True),
1128
+ score_scale_factor=kwargs.get("score_scale_factor", 1.0),
1129
+ add_weight_norm=kwargs.get("add_weight_norm", False),
1130
+ )
1131
+ else:
1132
+ raise NotImplementedError
1133
+
1134
+ def forward(self, x, attention_mask=None) -> MoEMlpOutput:
1135
+ original_shape = x.shape[:-1]
1136
+ x = x.reshape(-1, self.input_size)
1137
+ flattened_mask = None
1138
+ if attention_mask is not None and len(attention_mask.shape) == 2:
1139
+ flattened_mask = attention_mask.flatten()
1140
+ flattened_shape = flattened_mask.shape
1141
+ x = x[flattened_mask.bool()]
1142
+
1143
+ gate_outputs: dict = self.gate(x)
1144
+ calc_outs: CalculatorOutput = self.calculator(x, **gate_outputs)
1145
+
1146
+ y = calc_outs.hidden_states
1147
+ if flattened_mask is not None:
1148
+ y = torch.zeros(flattened_shape + (self.output_size,), dtype=x.dtype, device=x.device) # (batch_size*seq_len, output_size)
1149
+ y[flattened_mask.bool()] = calc_outs.hidden_states # (non_padding_num, output_size)
1150
+ y = y.reshape(original_shape + (self.output_size,))
1151
+
1152
+ return MoEMlpOutput(
1153
+ hidden_states=y,
1154
+ balance_loss=gate_outputs.get("balance_loss"),
1155
+ num_dropped_tokens=calc_outs.num_dropped_tokens,
1156
+ gate_load=gate_outputs.get("load", torch.tensor(-1)),
1157
+ gate_importance=gate_outputs.get("importance", torch.tensor(-1)),
1158
+ )
1159
+
1160
+ def reset_gate_network(self):
1161
+ self.gate.reset_gate_network()
1162
+
1163
+ def reset_experts(self):
1164
+ self.calculator.reset_experts()
1165
+
1166
+
1167
+ class LinearGLUMoELayer(BaseMoELayer):
1168
+ def __init__(
1169
+ self,
1170
+ input_size,
1171
+ hidden_size,
1172
+ output_size,
1173
+ hidden_act,
1174
+ num_experts,
1175
+ num_selects,
1176
+ size_experts=None,
1177
+ bias=True,
1178
+ **kwargs,
1179
+ ):
1180
+ super(LinearGLUMoELayer, self).__init__()
1181
+ assert num_selects <= num_experts
1182
+ self.input_size = input_size
1183
+ self.hidden_size = hidden_size
1184
+ self.output_size = output_size
1185
+ self.hidden_act = hidden_act
1186
+ self.num_experts = num_experts
1187
+ self.num_selects = num_selects
1188
+ self.size_experts = size_experts
1189
+ self.bias = bias
1190
+
1191
+ experts = LinearGLUExperts(
1192
+ input_size,
1193
+ hidden_size,
1194
+ output_size,
1195
+ hidden_act,
1196
+ num_experts,
1197
+ size_experts=size_experts,
1198
+ bias=bias,
1199
+ )
1200
+
1201
+ self._create_gate(**kwargs)
1202
+ self._create_calculator(experts, **kwargs)
1203
+
1204
+
1205
+ class LlamaMoEDecoderLayer(nn.Module):
1206
+ def __init__(self, config: LlamaMoEConfig, layer_index):
1207
+ super().__init__()
1208
+
1209
+ self.hidden_size = config.hidden_size
1210
+ # self.self_attn = LlamaAttention(config=config)
1211
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_index)
1212
+
1213
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1214
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1215
+
1216
+ gating_config = {
1217
+ # all gates
1218
+ "gate_type": config.gate_type,
1219
+ "gate_network": config.gate_network,
1220
+ "gate_use_softmax": config.gate_use_softmax,
1221
+ "gate_use_balance": config.gate_use_balance,
1222
+ "gate_balance_loss_weight": config.gate_balance_loss_weight,
1223
+ "gate_add_noise": config.gate_add_noise,
1224
+ # TopKBalancedNoisyGate
1225
+ "gate_noise_epsilon": config.gate_noise_epsilon,
1226
+ }
1227
+ calculator_config = {
1228
+ # all calculators
1229
+ "calculator_type": config.calculator_type,
1230
+ "multiply_gate_scores": config.multiply_gate_scores,
1231
+ "score_scale_factor": (
1232
+ config.score_scale_factor[layer_index]
1233
+ if isinstance(config.score_scale_factor, list)
1234
+ else config.score_scale_factor
1235
+ ),
1236
+ "add_weight_norm": config.add_weight_norm,
1237
+ # SwitchDropTokenCalculator
1238
+ "drop_tokens": config.drop_tokens,
1239
+ "dropped_padding": config.dropped_padding,
1240
+ "capacity_factor": config.capacity_factor,
1241
+ }
1242
+
1243
+ self.mlp = LinearGLUMoELayer(
1244
+ input_size=self.hidden_size,
1245
+ hidden_size=config.intermediate_size,
1246
+ output_size=self.hidden_size,
1247
+ hidden_act=config.hidden_act,
1248
+ num_experts=config.num_experts,
1249
+ num_selects=config.num_selects,
1250
+ size_experts=(
1251
+ config.size_experts[layer_index]
1252
+ if config.size_experts is not None
1253
+ else None
1254
+ ),
1255
+ bias=False,
1256
+ **gating_config,
1257
+ **calculator_config,
1258
+ )
1259
+
1260
+ def forward(
1261
+ self,
1262
+ hidden_states,
1263
+ attention_mask=None,
1264
+ position_ids=None,
1265
+ past_key_value=None,
1266
+ output_attentions=False,
1267
+ use_cache=False,
1268
+ ) -> tuple:
1269
+ residual = hidden_states
1270
+ hidden_states = self.input_layernorm(hidden_states)
1271
+
1272
+ # Self Attention
1273
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1274
+ hidden_states=hidden_states,
1275
+ attention_mask=attention_mask,
1276
+ position_ids=position_ids,
1277
+ past_key_value=past_key_value,
1278
+ output_attentions=output_attentions,
1279
+ use_cache=use_cache,
1280
+ )
1281
+ hidden_states = residual + hidden_states
1282
+
1283
+ # Fully Connected
1284
+ residual = hidden_states
1285
+ hidden_states = self.post_attention_layernorm(hidden_states)
1286
+ mlp_outs: MoEMlpOutput = self.mlp(hidden_states, attention_mask=attention_mask)
1287
+ hidden_states = residual + mlp_outs.hidden_states
1288
+
1289
+ outputs = (
1290
+ hidden_states,
1291
+ mlp_outs.balance_loss,
1292
+ mlp_outs.num_dropped_tokens,
1293
+ mlp_outs.gate_load,
1294
+ mlp_outs.gate_importance,
1295
+ )
1296
+ if output_attentions:
1297
+ outputs += (self_attn_weights,)
1298
+ if use_cache:
1299
+ outputs += (present_key_value,)
1300
+
1301
+ return outputs
1302
+
1303
+
1304
+ class LlamaMoEPreTrainedModel(PreTrainedModel):
1305
+ config_class = LlamaMoEConfig
1306
+ base_model_prefix = "model"
1307
+ supports_gradient_checkpointing = True
1308
+ _no_split_modules = ["LlamaMoEDecoderLayer"]
1309
+ _skip_keys_device_placement = "past_key_values"
1310
+ _supports_flash_attn_2 = True
1311
+
1312
+ def _init_weights(self, module):
1313
+ std = self.config.initializer_range
1314
+ if isinstance(module, nn.Linear):
1315
+ module.weight.data.normal_(mean=0.0, std=std)
1316
+ if module.bias is not None:
1317
+ module.bias.data.zero_()
1318
+ elif isinstance(module, nn.Embedding):
1319
+ module.weight.data.normal_(mean=0.0, std=std)
1320
+ if module.padding_idx is not None:
1321
+ module.weight.data[module.padding_idx].zero_()
1322
+
1323
+
1324
+ class LlamaMoEModel(LlamaMoEPreTrainedModel):
1325
+ def __init__(self, config: LlamaMoEConfig):
1326
+ super().__init__(config)
1327
+ self.padding_idx = config.pad_token_id
1328
+ self.vocab_size = config.vocab_size
1329
+
1330
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1331
+ self.layers = nn.ModuleList(
1332
+ [LlamaMoEDecoderLayer(config, i) for i in range(config.num_hidden_layers)]
1333
+ )
1334
+ self._use_sdpa = config._attn_implementation == "sdpa"
1335
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1336
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1337
+ self.gradient_checkpointing = False
1338
+ self.post_init()
1339
+
1340
+ def get_input_embeddings(self):
1341
+ return self.embed_tokens
1342
+
1343
+ def set_input_embeddings(self, value):
1344
+ self.embed_tokens = value
1345
+
1346
+ def forward(
1347
+ self,
1348
+ input_ids=None,
1349
+ attention_mask=None,
1350
+ position_ids=None,
1351
+ past_key_values=None,
1352
+ inputs_embeds=None,
1353
+ use_cache=None,
1354
+ output_attentions=None,
1355
+ output_hidden_states=None,
1356
+ return_dict=None,
1357
+ ):
1358
+ output_attentions = (
1359
+ output_attentions
1360
+ if output_attentions is not None
1361
+ else self.config.output_attentions
1362
+ )
1363
+ output_hidden_states = (
1364
+ output_hidden_states
1365
+ if output_hidden_states is not None
1366
+ else self.config.output_hidden_states
1367
+ )
1368
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1369
+
1370
+ return_dict = (
1371
+ return_dict if return_dict is not None else self.config.use_return_dict
1372
+ )
1373
+
1374
+ # retrieve input_ids and inputs_embeds
1375
+ if input_ids is not None and inputs_embeds is not None:
1376
+ raise ValueError(
1377
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at"
1378
+ " the same time"
1379
+ )
1380
+ elif input_ids is not None:
1381
+ batch_size, seq_length = input_ids.shape
1382
+ elif inputs_embeds is not None:
1383
+ batch_size, seq_length, _ = inputs_embeds.shape
1384
+ else:
1385
+ raise ValueError(
1386
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1387
+ )
1388
+
1389
+ if self.gradient_checkpointing and self.training:
1390
+ if use_cache:
1391
+ logger.warning_once(
1392
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1393
+ )
1394
+ use_cache = False
1395
+
1396
+ past_key_values_length = 0
1397
+ if use_cache:
1398
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1399
+ if use_legacy_cache:
1400
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1401
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1402
+
1403
+ if position_ids is None:
1404
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1405
+ position_ids = torch.arange(
1406
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1407
+ )
1408
+ position_ids = position_ids.unsqueeze(0)
1409
+
1410
+ if inputs_embeds is None:
1411
+ inputs_embeds = self.embed_tokens(input_ids)
1412
+
1413
+ if self._use_flash_attention_2:
1414
+ # 2d mask is passed through the layers
1415
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1416
+ elif self._use_sdpa and not output_attentions:
1417
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1418
+ # the manual implementation that requires a 4D causal mask in all cases.
1419
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1420
+ attention_mask,
1421
+ (batch_size, seq_length),
1422
+ inputs_embeds,
1423
+ past_key_values_length,
1424
+ )
1425
+ else:
1426
+ # 4d mask is passed through the layers
1427
+ attention_mask = _prepare_4d_causal_attention_mask(
1428
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1429
+ )
1430
+
1431
+ hidden_states = inputs_embeds
1432
+ balance_loss = 0.0
1433
+
1434
+ # decoder layers
1435
+ all_hidden_states = () if output_hidden_states else None
1436
+ all_self_attns = () if output_attentions else None
1437
+ next_decoder_cache = None
1438
+
1439
+ num_dropped_tokens = ()
1440
+ gate_load = ()
1441
+ gate_importance = ()
1442
+ for idx, decoder_layer in enumerate(self.layers):
1443
+ if output_hidden_states:
1444
+ all_hidden_states += (hidden_states,)
1445
+
1446
+ if self.gradient_checkpointing and self.training:
1447
+ layer_outputs = self._gradient_checkpointing_func(
1448
+ decoder_layer.__call__,
1449
+ hidden_states,
1450
+ attention_mask,
1451
+ position_ids,
1452
+ past_key_values,
1453
+ output_attentions,
1454
+ use_cache,
1455
+ )
1456
+ else:
1457
+ layer_outputs = decoder_layer(
1458
+ hidden_states,
1459
+ attention_mask=attention_mask,
1460
+ position_ids=position_ids,
1461
+ past_key_value=past_key_values,
1462
+ output_attentions=output_attentions,
1463
+ use_cache=use_cache,
1464
+ )
1465
+
1466
+ hidden_states = layer_outputs[0]
1467
+ if layer_outputs[1] is not None:
1468
+ balance_loss += layer_outputs[1]
1469
+
1470
+ if use_cache:
1471
+ next_decoder_cache = layer_outputs[6 if output_attentions else 5]
1472
+
1473
+ if output_attentions:
1474
+ all_self_attns += (layer_outputs[5],)
1475
+
1476
+ num_dropped_tokens += (layer_outputs[2],)
1477
+ gate_load += (layer_outputs[3],)
1478
+ gate_importance += (layer_outputs[4],)
1479
+
1480
+ hidden_states = self.norm(hidden_states)
1481
+
1482
+ # add hidden states from the last decoder layer
1483
+ if output_hidden_states:
1484
+ all_hidden_states += (hidden_states,)
1485
+
1486
+ next_cache = None
1487
+ if use_cache:
1488
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1489
+ if not return_dict:
1490
+ return tuple(
1491
+ v
1492
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1493
+ if v is not None
1494
+ )
1495
+ return BaseMoEModelOutputWithPast(
1496
+ last_hidden_state=hidden_states,
1497
+ balance_loss=balance_loss,
1498
+ past_key_values=next_cache,
1499
+ hidden_states=all_hidden_states,
1500
+ attentions=all_self_attns,
1501
+ num_dropped_tokens=num_dropped_tokens,
1502
+ gate_load=gate_load,
1503
+ gate_importance=gate_importance,
1504
+ )
1505
+
1506
+ def reset_gate_network(self):
1507
+ for idx, decoder_layer in enumerate(self.layers):
1508
+ decoder_layer.reset_gate_network()
1509
+
1510
+ def reset_experts(self):
1511
+ for idx, decoder_layer in enumerate(self.layers):
1512
+ decoder_layer.reset_experts()
1513
+
1514
+
1515
+ class LlamaMoEForCausalLM(LlamaMoEPreTrainedModel):
1516
+ _tied_weights_keys = ["lm_head.weight"]
1517
+
1518
+ def __init__(self, config):
1519
+ super().__init__(config)
1520
+ self.model = LlamaMoEModel(config)
1521
+ self.pretraining_tp = config.pretraining_tp
1522
+ self.vocab_size = config.vocab_size
1523
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1524
+
1525
+ # Initialize weights and apply final processing
1526
+ self.post_init()
1527
+
1528
+ def get_input_embeddings(self):
1529
+ return self.model.embed_tokens
1530
+
1531
+ def set_input_embeddings(self, value):
1532
+ self.model.embed_tokens = value
1533
+
1534
+ def get_output_embeddings(self):
1535
+ return self.lm_head
1536
+
1537
+ def set_output_embeddings(self, new_embeddings):
1538
+ self.lm_head = new_embeddings
1539
+
1540
+ def set_decoder(self, decoder):
1541
+ self.model = decoder
1542
+
1543
+ def get_decoder(self):
1544
+ return self.model
1545
+
1546
+ def forward(
1547
+ self,
1548
+ input_ids=None,
1549
+ attention_mask=None,
1550
+ position_ids=None,
1551
+ past_key_values=None,
1552
+ inputs_embeds=None,
1553
+ labels=None,
1554
+ use_cache=None,
1555
+ output_attentions=None,
1556
+ output_hidden_states=None,
1557
+ return_dict=None,
1558
+ **kwargs,
1559
+ ):
1560
+ output_attentions = (
1561
+ output_attentions
1562
+ if output_attentions is not None
1563
+ else self.config.output_attentions
1564
+ )
1565
+ output_hidden_states = (
1566
+ output_hidden_states
1567
+ if output_hidden_states is not None
1568
+ else self.config.output_hidden_states
1569
+ )
1570
+ return_dict = (
1571
+ return_dict if return_dict is not None else self.config.use_return_dict
1572
+ )
1573
+
1574
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1575
+ outputs: BaseMoEModelOutputWithPast = self.model(
1576
+ input_ids=input_ids,
1577
+ attention_mask=attention_mask,
1578
+ position_ids=position_ids,
1579
+ past_key_values=past_key_values,
1580
+ inputs_embeds=inputs_embeds,
1581
+ use_cache=use_cache,
1582
+ output_attentions=output_attentions,
1583
+ output_hidden_states=output_hidden_states,
1584
+ return_dict=return_dict,
1585
+ )
1586
+
1587
+ hidden_states = outputs.last_hidden_state
1588
+ logits = self.lm_head(hidden_states)
1589
+
1590
+ loss = None
1591
+ if labels is not None:
1592
+ # Shift so that tokens < n predict n
1593
+ shift_logits = logits[..., :-1, :].contiguous()
1594
+ shift_labels = labels[..., 1:].contiguous()
1595
+ # Flatten the tokens
1596
+ loss_fct = nn.CrossEntropyLoss()
1597
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1598
+ shift_labels = shift_labels.view(-1)
1599
+ # Enable model parallelism
1600
+ shift_labels = shift_labels.to(shift_logits.device)
1601
+ loss = loss_fct(shift_logits, shift_labels)
1602
+ if outputs.balance_loss is not None and outputs.balance_loss > 0:
1603
+ loss += outputs.balance_loss
1604
+
1605
+ if not return_dict:
1606
+ output = (logits,) + outputs[1:]
1607
+ return (loss,) + output if loss is not None else output
1608
+
1609
+ return MoECausalLMOutputWithPast(
1610
+ loss=loss,
1611
+ logits=logits,
1612
+ past_key_values=outputs.past_key_values,
1613
+ hidden_states=outputs.hidden_states,
1614
+ attentions=outputs.attentions,
1615
+ num_dropped_tokens=outputs.num_dropped_tokens,
1616
+ balance_loss=outputs.balance_loss,
1617
+ gate_load=outputs.gate_load,
1618
+ gate_importance=outputs.gate_importance,
1619
+ )
1620
+
1621
+ def prepare_inputs_for_generation(
1622
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1623
+ ):
1624
+ if past_key_values is not None:
1625
+ if isinstance(past_key_values, Cache):
1626
+ cache_length = past_key_values.get_seq_length()
1627
+ past_length = past_key_values.seen_tokens
1628
+ max_cache_length = past_key_values.get_max_length()
1629
+ else:
1630
+ cache_length = past_length = past_key_values[0][0].shape[2]
1631
+ max_cache_length = None
1632
+
1633
+ # Keep only the unprocessed tokens:
1634
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1635
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1636
+ # input)
1637
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1638
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1639
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1640
+ # input_ids based on the past_length.
1641
+ elif past_length < input_ids.shape[1]:
1642
+ input_ids = input_ids[:, past_length:]
1643
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1644
+
1645
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1646
+ if (
1647
+ max_cache_length is not None
1648
+ and attention_mask is not None
1649
+ and cache_length + input_ids.shape[1] > max_cache_length
1650
+ ):
1651
+ attention_mask = attention_mask[:, -max_cache_length:]
1652
+
1653
+ position_ids = kwargs.get("position_ids", None)
1654
+ if attention_mask is not None and position_ids is None:
1655
+ # create position_ids on the fly for batch generation
1656
+ position_ids = attention_mask.long().cumsum(-1) - 1
1657
+ position_ids.masked_fill_(attention_mask == 0, 1)
1658
+ if past_key_values:
1659
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1660
+
1661
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1662
+ if inputs_embeds is not None and past_key_values is None:
1663
+ model_inputs = {"inputs_embeds": inputs_embeds}
1664
+ else:
1665
+ model_inputs = {"input_ids": input_ids}
1666
+
1667
+ model_inputs.update(
1668
+ {
1669
+ "position_ids": position_ids,
1670
+ "past_key_values": past_key_values,
1671
+ "use_cache": kwargs.get("use_cache"),
1672
+ "attention_mask": attention_mask,
1673
+ }
1674
+ )
1675
+ return model_inputs
1676
+
1677
+ @staticmethod
1678
+ def _reorder_cache(past_key_values, beam_idx):
1679
+ reordered_past = ()
1680
+ for layer_past in past_key_values:
1681
+ reordered_past += (
1682
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1683
+ )
1684
+ return reordered_past
1685
+
1686
+ def reset_gate_network(self):
1687
+ self.model.reset_gate_network()
1688
+
1689
+ def reset_experts(self):
1690
+ self.model.reset_experts()
sampling_info/100/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/100/prob_map.pdf ADDED
Binary file (13.5 kB). View file
 
sampling_info/100/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1000/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1000/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1000/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1100/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1100/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1100/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1200/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1200/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1200/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1300/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1300/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1300/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1400/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1400/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1400/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1500/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1500/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1500/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1600/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1600/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1600/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1700/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1700/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1700/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1800/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1800/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1800/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/1900/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/1900/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/1900/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/200/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/200/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/200/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/2000/load.pdf ADDED
Binary file (12.6 kB). View file
 
sampling_info/2000/prob_map.pdf ADDED
Binary file (13.6 kB). View file
 
sampling_info/2000/sim.pdf ADDED
Binary file (11.5 kB). View file
 
sampling_info/300/load.pdf ADDED
Binary file (12.6 kB). View file