Crystalcareai commited on
Commit
bad6a57
1 Parent(s): 6f6cbec

Update modeling_gemmoe.py

Browse files
Files changed (1) hide show
  1. modeling_gemmoe.py +14 -11
modeling_gemmoe.py CHANGED
@@ -655,31 +655,34 @@ class GemmoeSparseMoeBlock(nn.Module):
655
  self.num_experts = config.num_local_experts
656
  self.top_k = 2
657
 
 
658
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
 
659
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
660
 
661
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
662
  batch_size, sequence_length, hidden_dim = hidden_states.shape
663
  hidden_states = hidden_states.view(-1, hidden_dim)
664
 
 
665
  router_logits = self.gate(hidden_states)
666
  routing_weights = F.softmax(router_logits, dim=1)
667
  topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
668
  topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
669
 
670
- expert_outputs = []
 
 
 
 
671
  for i in range(self.num_experts):
672
- expert_input = hidden_states[topk_idx[:, i]]
673
- expert_output = self.experts[i](expert_input)
674
- expert_outputs.append(expert_output)
675
 
676
- expert_outputs = torch.stack(expert_outputs, dim=1)
677
- expert_outputs = expert_outputs.view(batch_size, sequence_length, self.top_k, -1)
678
- topk_weight = topk_weight.view(batch_size, sequence_length, self.top_k, 1)
679
-
680
- final_hidden_states = (expert_outputs * topk_weight).sum(dim=2)
681
- final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim)
682
-
683
  return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
684
 
685
 
 
655
  self.num_experts = config.num_local_experts
656
  self.top_k = 2
657
 
658
+ # gating
659
  self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
660
+
661
  self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
662
 
663
  def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
664
  batch_size, sequence_length, hidden_dim = hidden_states.shape
665
  hidden_states = hidden_states.view(-1, hidden_dim)
666
 
667
+ # router_logits: (batch * sequence_length, n_experts)
668
  router_logits = self.gate(hidden_states)
669
  routing_weights = F.softmax(router_logits, dim=1)
670
  topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
671
  topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
672
 
673
+ hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
674
+
675
+ y = torch.empty_like(hidden_states)
676
+
677
+ flat_topk_idx = topk_idx.view(-1)
678
  for i in range(self.num_experts):
679
+ expert = self.experts[i]
680
+ expert_output = expert(hidden_states[flat_topk_idx == i])
681
+ y[flat_topk_idx == i] = expert_output
682
 
683
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
684
+
685
+ final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
 
 
 
 
686
  return final_hidden_states.to(hidden_states.dtype), router_logits.to(hidden_states.dtype)
687
 
688