Crystalcareai committed
Update modeling_gemmoe.py

modeling_gemmoe.py CHANGED (+4, -1)
@@ -669,7 +669,7 @@ class GemmoeSparseMoeBlock(nn.Module):
 
         self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.to(self.gate.weight.device)
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -707,6 +707,9 @@ class GemmoeSparseMoeBlock(nn.Module):
 
         final_hidden_states.index_add_(0, token_indices, current_hidden_states)
 
+        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
 
 class GemmoeDecoderLayer(nn.Module):
     def __init__(self, config: GemmoeConfig, layer_idx: int):
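
For context, a minimal, self-contained sketch of the behavior this commit gives GemmoeSparseMoeBlock.forward: flattened tokens are routed to experts, expert outputs are accumulated with index_add_ as in the hunk above, the result is reshaped back to (batch, seq, hidden), and the router logits are returned alongside it (router logits are typically consumed by an auxiliary load-balancing loss). The toy gate, expert layers, and top-2 routing below are assumptions modeled on Mixtral-style MoE code, not the actual Gemmoe implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class ToySparseMoeBlock(nn.Module):
    """Stand-in for GemmoeSparseMoeBlock; the linear experts, gate, and
    top-2 routing here are assumptions, not the Gemmoe internals."""

    def __init__(self, hidden_dim: int = 8, num_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # Flatten (batch, seq) into one token axis, as in the patched forward.
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)              # (tokens, experts)
        routing_weights = F.softmax(router_logits, dim=-1)
        weights, selected = routing_weights.topk(self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize top-k

        final_hidden_states = torch.zeros_like(hidden_states)
        for expert_idx, expert in enumerate(self.experts):
            # Tokens (and which of their top-k slots) routed to this expert.
            token_indices, k_slot = (selected == expert_idx).nonzero(as_tuple=True)
            if token_indices.numel() == 0:
                continue
            current_hidden_states = (
                expert(hidden_states[token_indices])
                * weights[token_indices, k_slot].unsqueeze(-1)
            )
            # Same accumulation step shown in the diff.
            final_hidden_states.index_add_(0, token_indices, current_hidden_states)

        # The lines this commit adds: restore (batch, seq, hidden) and
        # return the router logits alongside the hidden states.
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

moe = ToySparseMoeBlock()
out, logits = moe(torch.randn(2, 5, 8))
print(out.shape, logits.shape)  # torch.Size([2, 5, 8]) torch.Size([10, 4])

With the new Tuple return type, callers unpack both values, matching how decoder layers in Mixtral-style models collect per-layer router logits.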