marcusguhao committed
Commit cada6ff · verified · Parent(s): 4c8b97d

Upload baichuan_moe.py

Files changed (1):
  baichuan_moe.py +683 -0
baichuan_moe.py ADDED
# 2024 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
"""Inference-only Baichuan-MOE model."""
import time

from transformers.configuration_utils import PretrainedConfig


class BaiChuanMoEConfig(PretrainedConfig):
    model_type = "baichuan-moe"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=64000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        rope_base=1e6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        moe_experts_fixed=0,
        moe_experts_selected=2,
        moe_experts_routed=8,
        num_experts_fixed_per_layer=None,  # "0,0,0,1,0,2"
        num_experts_selected_per_layer=None,  # "1,2,1,1,1,2"
        num_experts_routed_per_layer=None,  # "1,8,1,8,1,16"
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_base = rope_base
        self.use_cache = use_cache
        self.moe_experts_fixed = moe_experts_fixed
        self.moe_experts_selected = moe_experts_selected
        self.moe_experts_routed = moe_experts_routed
        if num_experts_routed_per_layer:
            self.num_experts_routed_per_layer = [
                int(x.strip()) for x in num_experts_routed_per_layer.split(",")
            ]
            assert len(self.num_experts_routed_per_layer) == self.num_hidden_layers
            assert all(x >= 1 for x in self.num_experts_routed_per_layer)
        else:
            self.num_experts_routed_per_layer = None

        if num_experts_selected_per_layer:
            self.num_experts_selected_per_layer = [
                int(x.strip()) for x in num_experts_selected_per_layer.split(",")
            ]
            assert len(self.num_experts_selected_per_layer) == self.num_hidden_layers
            assert all(
                routed >= selected for routed, selected in zip(
                    self.num_experts_routed_per_layer,
                    self.num_experts_selected_per_layer))
        else:
            self.num_experts_selected_per_layer = None

        if num_experts_fixed_per_layer:
            self.num_experts_fixed_per_layer = [
                int(x.strip()) for x in num_experts_fixed_per_layer.split(",")
            ]
            assert len(self.num_experts_fixed_per_layer) == self.num_hidden_layers
        else:
            self.num_experts_fixed_per_layer = None

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
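
# Illustrative usage (a minimal sketch, not taken from a released checkpoint):
# the three *_per_layer arguments are comma-separated strings with one entry
# per hidden layer, parsed by the constructor above. The values below are
# made-up numbers for a hypothetical 4-layer model.
#
#     cfg = BaiChuanMoEConfig(
#         num_hidden_layers=4,
#         num_experts_routed_per_layer="1,8,8,16",   # 1 routed expert => dense MLP layer
#         num_experts_selected_per_layer="1,2,2,2",  # top-k experts per token
#         num_experts_fixed_per_layer="0,1,1,0",     # always-on (shared) experts
#     )
#     assert cfg.num_experts_routed_per_layer == [1, 8, 8, 16]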

import copy
import math
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe

from vllm.model_executor.layers.activation import SiluAndMul, GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear,
                                               LinearMethodBase)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, PPMissingLayer,
                    make_empty_intermediate_tensors_factory, make_layers)


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes
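
# Worked example (illustrative): for total_num_heads=8, closest_power_of_2=8,
# base = 2**(-(2**-(3 - 3))) = 0.5, so the ALiBi slopes are
# [2**-1, 2**-2, ..., 2**-8]. Note that this helper is not referenced anywhere
# in this file; the attention layers below use rotary embeddings instead.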


class MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        layer_index: int,
    ):
        super().__init__()
        self.layer_index = layer_index
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False)
        if hidden_act not in ["silu", "gelu"]:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu and gelu are supported for now.")
        self.act_fn = SiluAndMul() if hidden_act == "silu" else GeluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        ret, _ = self.down_proj(x)
        return ret


class MixtralMLP(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full
    capacity (no dropped tokens). It is faster because it formulates the MoE
    operations as block-sparse operations to accommodate imbalanced
    assignments of tokens to experts, whereas standard MoE either (1) drops
    tokens at the cost of reduced performance or (2) sets the capacity factor
    to the number of experts and thus wastes computation and memory on
    padding.
    """

    def __init__(self,
                 hidden_size,
                 intermediate_size,
                 hidden_act,
                 moe_experts_routed,
                 moe_experts_selected,
                 moe_experts_fixed,
                 layer_index,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None):
        super().__init__()

        self.layer_index = layer_index
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()

        self.num_experts_routed = moe_experts_routed
        self.num_local_experts_routed = self.num_experts_routed // 1
        self.top_k = moe_experts_selected
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.router = ReplicatedLinear(self.hidden_size,
                                       self.num_experts_routed,
                                       bias=False,
                                       params_dtype=self.params_dtype)

        self.ws = nn.Parameter(
            torch.empty(self.num_experts_routed,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_experts_routed,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

        if moe_experts_fixed >= 1:
            self.local_experts_fixed = MLP(hidden_size,
                                           intermediate_size * moe_experts_fixed,
                                           hidden_act, layer_index)
        else:
            self.local_experts_fixed = None

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("gate_proj.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("up_proj.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("down_proj.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        router_logits, _ = self.router(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True)

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        final_hidden_states = final_hidden_states.view(num_tokens, hidden_size)

        if self.local_experts_fixed:
            # Average the routed-expert output with the fixed (shared) expert.
            final_hidden_states += self.local_experts_fixed(
                hidden_states).reshape(num_tokens, hidden_size)
            final_hidden_states /= 2

        return final_hidden_states.reshape(num_tokens, hidden_size)
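
# Reference semantics of the fused_moe call above, written out as a dense
# sketch for readability (illustrative only; the fused kernel is the actual
# implementation). With renormalize=True, each token's router probabilities
# come from a softmax, the top_k experts are kept, their weights are
# renormalized to sum to one, and the token goes through each selected
# expert's SwiGLU MLP (gate/up rows stacked in ws, down projection in w2s):
#
#     def _moe_reference(x, ws, w2s, router_logits, top_k):
#         probs = torch.softmax(router_logits, dim=-1)
#         topk_p, topk_i = probs.topk(top_k, dim=-1)
#         topk_p = topk_p / topk_p.sum(dim=-1, keepdim=True)   # renormalize
#         out = torch.zeros_like(x)
#         for t in range(x.shape[0]):
#             for p, e in zip(topk_p[t], topk_i[t]):
#                 gate_up = ws[e] @ x[t]                        # [2 * intermediate]
#                 gate, up = gate_up.chunk(2)
#                 h = torch.nn.functional.silu(gate) * up       # SwiGLU
#                 out[t] += p * (w2s[e] @ h)
#         return out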


class MixtralAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than the TP size, so we replicate
            # the KV heads across the tensor-parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
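
# Worked example of the KV-head sharding rule in MixtralAttention.__init__
# (illustrative numbers only): with total_num_kv_heads=32 and tp_size=8, each
# rank keeps 32 // 8 = 4 KV heads; with total_num_kv_heads=4 and tp_size=8,
# each rank keeps max(1, 4 // 8) = 1 replicated KV head and tp_size must be a
# multiple of total_num_kv_heads. This model passes
# num_kv_heads=config.num_attention_heads, so the partitioning branch is the
# one normally taken.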


class DecoderLayer(nn.Module):

    def __init__(
        self,
        config: BaiChuanMoEConfig,
        linear_method: Optional[LinearMethodBase] = None,
        layer_index: Optional[int] = 1,
    ) -> None:
        super().__init__()

        self.layer_index = layer_index
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_base", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_attention_heads,
            rope_theta=rope_theta,
            linear_method=linear_method)

        # Dense layer: a single routed expert degenerates to a plain MLP.
        if config.moe_experts_routed == 1:
            self.mlp = MLP(hidden_size=config.hidden_size,
                           intermediate_size=config.intermediate_size,
                           hidden_act=config.hidden_act,
                           layer_index=layer_index)
        # MoE layer
        else:
            self.mlp = MixtralMLP(config.hidden_size,
                                  config.intermediate_size,
                                  config.hidden_act,
                                  config.moe_experts_routed,
                                  config.moe_experts_selected,
                                  config.moe_experts_fixed,
                                  layer_index)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


class Model(nn.Module):

    def __init__(
        self,
        config: BaiChuanMoEConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        if config.num_experts_routed_per_layer:
            # Per-layer MoE configuration: build each decoder layer with its
            # own expert counts.
            layers = []
            for index in range(config.num_hidden_layers):
                config_ = copy.deepcopy(config)
                config_.moe_experts_fixed = config.num_experts_fixed_per_layer[index]
                config_.moe_experts_selected = config.num_experts_selected_per_layer[index]
                config_.moe_experts_routed = config.num_experts_routed_per_layer[index]

                layers.append(
                    DecoderLayer(config_, linear_method=None, layer_index=index))
            self.layers = nn.ModuleList(layers)
        else:
            self.layers = nn.ModuleList([
                DecoderLayer(config, linear_method=linear_method)
                for _ in range(config.num_hidden_layers)
            ])

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    kv_caches[idx],
                                                    attn_metadata, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class NormHead(nn.Module):

    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size)))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.norm_weight = nn.functional.normalize(self.weight)

    def forward(self, hidden_states):
        return nn.functional.linear(hidden_states, self.norm_weight)
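
# Note: NormHead is kept here for reference and is not instantiated in this
# file; the same effect is obtained in BaiChuanMoEForCausalLM.load_weights
# below, which normalizes lm_head.weight once at load time.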


class BaiChuanMoEForCausalLM(nn.Module):
    # packed_modules_mapping = {
    #     "qkv_proj": [
    #         "q_proj",
    #         "k_proj",
    #         "v_proj",
    #     ],
    # }

    # # LoRA specific attributes
    # supported_lora_modules = [
    #     "qkv_proj",
    #     "o_proj",
    #     "embed_tokens",
    #     "lm_head",
    # ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: BaiChuanMoEConfig,
        linear_method: Optional[LinearMethodBase] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = Model(config,
                           linear_method,
                           lora_config=lora_config)
        # if get_pp_group().is_last_rank:
        #     self.unpadded_vocab_size = config.vocab_size
        #     if lora_config:
        #         self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        #     self.lm_head = ParallelLMHead(
        #         self.unpadded_vocab_size,
        #         config.hidden_size,
        #         org_num_embeddings=config.vocab_size,
        #         padding_size=(
        #             DEFAULT_VOCAB_PADDING_SIZE
        #             # We need bigger padding if using lora for kernel
        #             # compatibility
        #             if not lora_config else
        #             lora_config.lora_vocab_padding_size),
        #         quant_config=quant_config,
        #     )
        #     if config.tie_word_embeddings:
        #         self.lm_head = self.lm_head.tie_weights(
        #             self.model.embed_tokens)
        #
        #     logit_scale = getattr(config, "logit_scale", 1.0)
        #     self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
        #                                             config.vocab_size,
        #                                             logit_scale)
        #     self.sampler = Sampler()
        # else:
        #     self.lm_head = PPMissingLayer()
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            # We need bigger padding if using LoRA, for kernel compatibility.
            padding_size=(DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                          lora_config.lora_vocab_padding_size),
            quant_config=quant_config,
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("mlp.gate_up_proj", "mlp.gate_proj", 0),
            ("mlp.gate_up_proj", "mlp.up_proj", 1),
            ("mlp.local_experts_fixed.gate_up_proj",
             "mlp.local_experts_fixed.gate_proj", 0),
            ("mlp.local_experts_fixed.gate_up_proj",
             "mlp.local_experts_fixed.up_proj", 1),
        ]

        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
             f"local_experts_routed.{expert_id}.{weight_name}.weight",
             expert_id)
            for expert_id in range(16)
            for weight_name in ["gate_proj", "down_proj", "up_proj"]
        ]

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict.get(name, None)
                    if param is None:
                        # No matching parameter for this weight; skip it.
                        continue

                    if name == "lm_head.weight":
                        # Normalize the LM head weight at load time
                        # (NormHead behaviour).
                        norm_weight = nn.functional.normalize(loaded_weight)
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, norm_weight)
                    else:
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
        # 'model.layers.0.mlp.down_proj.weight_packed'
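
# Name-mapping examples for load_weights above (illustrative; the layer and
# expert indices are arbitrary):
#   model.layers.3.mlp.local_experts_fixed.gate_proj.weight
#       -> model.layers.3.mlp.local_experts_fixed.gate_up_proj.weight, shard 0
#   model.layers.3.mlp.local_experts_routed.5.up_proj.weight
#       -> model.layers.3.mlp.ws, expert 5 (second half of the stacked gate/up rows)
#   model.layers.3.mlp.local_experts_routed.5.down_proj.weight
#       -> model.layers.3.mlp.w2s, expert 5
#   lm_head.weight
#       -> normalized with nn.functional.normalize before loading (NormHead behaviour)
#   model.layers.3.self_attn.W_pack.weight
#       -> loaded as-is through the default path (QKV is already packed in the checkpoint)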