mashirong committed
Commit 8586474
1 Parent(s): ef87e36
Update modeling_deepseek.py

modeling_deepseek.py +9 -3

modeling_deepseek.py CHANGED
@@ -552,7 +552,9 @@ class DeepseekV2MoE(nn.Module):
             self.ep_rank = 0
             self.experts = nn.ModuleList(
                 [
-                    DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)
+                    DeepseekV2MLP(
+                        config, intermediate_size=config.moe_intermediate_size
+                    )
                     for i in range(config.n_routed_experts)
                 ]
             )
@@ -577,7 +579,7 @@ class DeepseekV2MoE(nn.Module):
             for i, expert in enumerate(self.experts):
                 y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-            y = y.view(*orig_shape)
+            y = y.to(hidden_states.dtype).view(*orig_shape)
             y = AddAuxiliaryLoss.apply(y, aux_loss)
         else:
             y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
@@ -1023,7 +1025,11 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
             elif torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             else:
-                target_dtype = self.q_proj.weight.dtype
+                target_dtype = (
+                    self.q_proj.weight.dtype
+                    if self.q_lora_rank is None
+                    else self.q_a_proj.weight.dtype
+                )

             logger.warning_once(
                 f"The input hidden states seems to be silently casted in float32, this might be related to"
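For context on the second hunk: the routing weights returned by the gate are float32 while the expert outputs follow the model dtype, and PyTorch's type promotion upcasts their product, so without the added cast the training-path MoE output silently becomes float32. A minimal, self-contained sketch of that promotion (toy shapes and variable names, not the repo's code):

import torch

tokens, top_k, hidden = 4, 2, 8
hidden_states = torch.randn(tokens, hidden, dtype=torch.bfloat16)       # model dtype
expert_out = torch.randn(tokens * top_k, hidden, dtype=torch.bfloat16)  # stand-in for y
topk_weight = torch.rand(tokens, top_k, dtype=torch.float32)            # gate weights in fp32

# bf16 * fp32 promotes to fp32, so the combined output leaves the model dtype ...
y = (expert_out.view(tokens, top_k, hidden) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(y.dtype)                           # torch.float32
# ... which is what the new `.to(hidden_states.dtype)` undoes before the reshape
print(y.to(hidden_states.dtype).dtype)   # torch.bfloat16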
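The last hunk accounts for the two possible query layouts in DeepseekV2Attention: a plain q_proj exists only when q_lora_rank is None, otherwise the query path is the factored q_a_proj/q_b_proj pair, so a dtype probe that assumes one layout raises AttributeError on checkpoints built with the other. A toy stand-in (hypothetical ToyAttention class, illustrative sizes, not the repo's module) showing the same selection the patch introduces:

import torch
from torch import nn

class ToyAttention(nn.Module):
    # Mimics the two query layouts the dtype probe has to handle.
    def __init__(self, hidden=16, q_lora_rank=None, dtype=torch.bfloat16):
        super().__init__()
        self.q_lora_rank = q_lora_rank
        if q_lora_rank is None:
            self.q_proj = nn.Linear(hidden, hidden, bias=False, dtype=dtype)
        else:
            self.q_a_proj = nn.Linear(hidden, q_lora_rank, bias=False, dtype=dtype)
            self.q_b_proj = nn.Linear(q_lora_rank, hidden, bias=False, dtype=dtype)

    def target_dtype(self):
        # Same branch the patch adds: probe whichever projection exists.
        return (
            self.q_proj.weight.dtype
            if self.q_lora_rank is None
            else self.q_a_proj.weight.dtype
        )

print(ToyAttention(q_lora_rank=None).target_dtype())  # torch.bfloat16, via q_proj
print(ToyAttention(q_lora_rank=4).target_dtype())     # torch.bfloat16, via q_a_proj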