mashirong commited on
Commit
8586474
1 Parent(s): ef87e36

Update modeling_deepseek.py

Browse files
Files changed (1) hide show
  1. modeling_deepseek.py +9 -3
modeling_deepseek.py CHANGED
@@ -552,7 +552,9 @@ class DeepseekV2MoE(nn.Module):
552
  self.ep_rank = 0
553
  self.experts = nn.ModuleList(
554
  [
555
- DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)
 
 
556
  for i in range(config.n_routed_experts)
557
  ]
558
  )
@@ -577,7 +579,7 @@ class DeepseekV2MoE(nn.Module):
577
  for i, expert in enumerate(self.experts):
578
  y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
579
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
580
- y = y.view(*orig_shape)
581
  y = AddAuxiliaryLoss.apply(y, aux_loss)
582
  else:
583
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
@@ -1023,7 +1025,11 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1023
  elif torch.is_autocast_enabled():
1024
  target_dtype = torch.get_autocast_gpu_dtype()
1025
  else:
1026
- target_dtype = self.q_proj.weight.dtype if self.q_lora_rank is None else self.q_a_proj.weight.dtype
 
 
 
 
1027
 
1028
  logger.warning_once(
1029
  f"The input hidden states seems to be silently casted in float32, this might be related to"
 
552
  self.ep_rank = 0
553
  self.experts = nn.ModuleList(
554
  [
555
+ DeepseekV2MLP(
556
+ config, intermediate_size=config.moe_intermediate_size
557
+ )
558
  for i in range(config.n_routed_experts)
559
  ]
560
  )
 
579
  for i, expert in enumerate(self.experts):
580
  y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
581
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
582
+ y = y.to(hidden_states.dtype).view(*orig_shape)
583
  y = AddAuxiliaryLoss.apply(y, aux_loss)
584
  else:
585
  y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
 
1025
  elif torch.is_autocast_enabled():
1026
  target_dtype = torch.get_autocast_gpu_dtype()
1027
  else:
1028
+ target_dtype = (
1029
+ self.q_proj.weight.dtype
1030
+ if self.q_lora_rank is None
1031
+ else self.q_a_proj.weight.dtype
1032
+ )
1033
 
1034
  logger.warning_once(
1035
  f"The input hidden states seems to be silently casted in float32, this might be related to"