Crystalcareai commited on
Commit
e6d7d0e
·
verified ·
1 Parent(s): 8825292

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +12 -26
modeling_gemmoe.py CHANGED
@@ -670,44 +670,30 @@ class GemmoeSparseMoeBlock(nn.Module):
670
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
671
 
672
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
673
- hidden_states = hidden_states.to(self.gate.weight.device)
674
  batch_size, sequence_length, hidden_dim = hidden_states.shape
675
  hidden_states = hidden_states.view(-1, hidden_dim)
676
 
677
  # router_logits: (batch * sequence_length, n_experts)
678
  router_logits = self.gate(hidden_states)
679
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
680
- top_routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
681
- top_routing_weights /= top_routing_weights.sum(dim=-1, keepdim=True)
682
 
683
  # we cast back to the input dtype
684
- top_routing_weights = top_routing_weights.to(hidden_states.dtype)
685
 
686
- final_hidden_states = torch.zeros(
687
- (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
688
- )
689
-
690
- # Loop over all available experts in the model and perform the computation on each expert
691
- for expert_idx in range(self.num_experts):
692
- expert_layer = self.experts[expert_idx]
693
- token_indices = (selected_experts == expert_idx).any(dim=-1).nonzero(as_tuple=True)[0]
694
-
695
- if token_indices.numel() == 0:
696
- continue
697
-
698
- current_state = hidden_states[token_indices]
699
- current_hidden_states = expert_layer(current_state)
700
 
701
- # Multiply the output hidden states by `top_routing_weights` on the corresponding tokens
702
- expert_indices = (selected_experts[token_indices] == expert_idx).nonzero(as_tuple=True)[1]
703
- current_hidden_states *= top_routing_weights[token_indices, expert_indices, None]
704
 
705
- # Cast current_hidden_states to the same data type as final_hidden_states
706
- current_hidden_states = current_hidden_states.to(final_hidden_states.dtype)
 
 
707
 
708
- final_hidden_states.index_add_(0, token_indices, current_hidden_states)
709
 
710
- final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
711
  return final_hidden_states, router_logits
712
 
713
 
 
670
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
671
 
672
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
673
  batch_size, sequence_length, hidden_dim = hidden_states.shape
674
  hidden_states = hidden_states.view(-1, hidden_dim)
675
 
676
  # router_logits: (batch * sequence_length, n_experts)
677
  router_logits = self.gate(hidden_states)
678
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
679
+ topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
680
+ topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
681
 
682
  # we cast back to the input dtype
683
+ topk_weight = topk_weight.to(hidden_states.dtype)
684
 
685
+ hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
686
 
687
+ y = torch.empty_like(hidden_states)
 
 
688
 
689
+ flat_topk_idx = topk_idx.view(-1)
690
+ for i in range(self.num_experts):
691
+ expert = self.experts[i]
692
+ y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
693
 
694
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
695
 
696
+ final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
697
  return final_hidden_states, router_logits
698
 
699