Update modeling_grok.py
modeling_grok.py  (CHANGED, +4 −4)
@@ -84,7 +84,7 @@ class GrokRMSNorm(nn.Module):
         GrokRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=torch.float32))
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
@@ -92,7 +92,7 @@ class GrokRMSNorm(nn.Module):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return (self.weight * hidden_states).to(input_dtype)
 
 
 # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Grok
@@ -338,7 +338,7 @@ class GrokDecoderLayer(nn.Module):
         self.top_k = config.num_experts_per_tok
 
         self.multi_head_attention = GrokAttention(config, layer_idx)
-        self.router = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+        self.router = nn.Linear(self.hidden_size, self.num_experts, dtype=torch.float32, bias=False)
         self.moe = nn.ModuleList([GrokBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
         self.rms_norm = GrokRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -400,7 +400,7 @@ class GrokDecoderLayer(nn.Module):
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.router(hidden_states)
+        router_logits = self.router(hidden_states.to(torch.float))
 
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)