Crystalcareai commited on
Commit
68b3eda
1 Parent(s): 6dc0ddc

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +1 -1
modeling_gemmoe.py CHANGED
@@ -711,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
711
  for i in range(self.num_experts):
712
  expert = self.experts[i]
713
  expert_output = expert(hidden_states[flat_topk_idx == i])
714
- y[flat_topk_idx == i] = expert_output
715
 
716
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
717
 
 
711
  for i in range(self.num_experts):
712
  expert = self.experts[i]
713
  expert_output = expert(hidden_states[flat_topk_idx == i])
714
+ y[flat_topk_idx == i] = expert_output.to(y.dtype) # Cast expert_output to the same dtype as y
715
 
716
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
717