bjoernp commited on
Commit
2989f91
1 Parent(s): 9b9979d

Update modeling_moe_mistral.py

Browse files
Files changed (1) hide show
  1. modeling_moe_mistral.py +4 -4
modeling_moe_mistral.py CHANGED
@@ -220,11 +220,11 @@ class MoE(nn.Module):
220
  flat_expert_indices = expert_indices.view(-1)
221
 
222
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
223
- x = torch.empty_like(x)
224
  for i, expert in enumerate(self.experts):
225
- x[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
226
- x = (x.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
227
- return x.view(*orig_shape)
228
 
229
 
230
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
 
220
  flat_expert_indices = expert_indices.view(-1)
221
 
222
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
223
+ y = torch.empty_like(x)
224
  for i, expert in enumerate(self.experts):
225
 + y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
226
+ y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
227
+ return y.view(*orig_shape)
228
 
229
 
230
  # Copied from transformers.models.llama.modeling_llama.repeat_kv