bjoernp commited on
Commit
3689565
1 Parent(s): b694477

Update modeling_moe_mistral.py

Browse files
Files changed (1) hide show
  1. modeling_moe_mistral.py +3 -4
modeling_moe_mistral.py CHANGED
@@ -196,9 +196,7 @@ class FeedForward(nn.Module):
196
  )
197
 
198
  def forward(self, x):
199
- device = x.device
200
- x = x.to(self.w1.weight.device)
201
- return self.w2(F.silu(self.w1(x)) * self.w3(x)).to(device)
202
 
203
 
204
  class MoE(nn.Module):
@@ -217,8 +215,9 @@ class MoE(nn.Module):
217
  orig_shape = x.shape
218
  x = x.view(-1, x.shape[-1])
219
 
220
- scores = self.gate(x).softmax(dim=-1)
221
  expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
 
222
  flat_expert_indices = expert_indices.view(-1)
223
 
224
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
 
196
  )
197
 
198
  def forward(self, x):
199
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
200
 
201
 
202
  class MoE(nn.Module):
 
215
  orig_shape = x.shape
216
  x = x.view(-1, x.shape[-1])
217
 
218
+ scores = self.gate(x)
219
  expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
220
+ expert_weights = expert_weights.softmax(dim=-1)
221
  flat_expert_indices = expert_indices.view(-1)
222
 
223
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)