ver217 committed on
Commit 527550f
1 Parent(s): 25d1d74

[hotfix] update modeling

Files changed (2)
  1. configuration_grok1.py +4 -0
  2. modeling_grok1.py +4 -2
configuration_grok1.py CHANGED
@@ -16,6 +16,8 @@ class Grok1Config(PretrainedConfig):
         attn_output_multiplier=1.0,
         max_attn_value=1.0,
         max_position_embeddings=4096,
+        embedding_multiplier_scale: float = 1.0,
+        output_multiplier_scale: float = 1.0,
         rms_norm_eps=1e-5,
         use_cache=True,
         pad_token_id=None,
@@ -32,6 +34,8 @@ class Grok1Config(PretrainedConfig):
         self.attn_output_multiplier = attn_output_multiplier
         self.max_attn_value = max_attn_value
         self.max_position_embeddings = max_position_embeddings
+        self.embedding_multiplier_scale = embedding_multiplier_scale
+        self.output_multiplier_scale = output_multiplier_scale
         self.hidden_size = hidden_size
         self.widening_factor = widening_factor
         self.num_hidden_layers = num_hidden_layers
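For reference, a minimal sketch (not part of the commit) of how the two new config fields could be set when building a config. It assumes configuration_grok1.py from this repo is importable from the working directory; the values shown are simply the defaults introduced above.

# Hypothetical usage of the two new config fields added in this commit.
# Assumes configuration_grok1.py from this repo is on the Python path.
from configuration_grok1 import Grok1Config

config = Grok1Config(
    embedding_multiplier_scale=1.0,  # placeholder; default shown in the diff above
    output_multiplier_scale=1.0,     # placeholder; default shown in the diff above
)
print(config.embedding_multiplier_scale, config.output_multiplier_scale)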
modeling_grok1.py CHANGED
@@ -259,8 +259,6 @@ class MultiHeadAttention(nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        # TODO: repeat kv
-
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
             torch.float
         )
@@ -536,6 +534,7 @@ class Grok1Model(Grok1PretrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.embedding_multiplier_scale = config.embedding_multiplier_scale
 
         self.embed_tokens = nn.Embedding(
             config.vocab_size, config.hidden_size, self.padding_idx
@@ -654,6 +653,7 @@ class Grok1Model(Grok1PretrainedModel):
 
         if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = inputs_embeds * self.embedding_multiplier_scale
 
         if HAS_MASK_UTILS:
             # 4d mask is passed through the layers
@@ -772,6 +772,7 @@ class Grok1ModelForCausalLM(Grok1PretrainedModel):
         super().__init__(config)
         self.model = Grok1Model(config)
         self.vocab_size = config.vocab_size
+        self.output_multiplier_scale = config.output_multiplier_scale
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.router_aux_loss_coef = config.router_aux_loss_coef
         self.num_experts = config.num_experts
@@ -846,6 +847,7 @@ class Grok1ModelForCausalLM(Grok1PretrainedModel):
 
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
+        logits = logits * self.output_multiplier_scale
         logits = logits.float()
 
         loss = None
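As a standalone illustration (not code from this repo), the snippet below mimics what the two added lines do: embeddings are scaled right after the lookup in Grok1Model.forward, and logits are scaled right after the lm_head projection in Grok1ModelForCausalLM.forward. Module sizes and scale values are placeholders, not the real Grok-1 settings.

import torch
import torch.nn as nn

# Placeholder values; the real ones come from Grok1Config.
embedding_multiplier_scale = 1.0
output_multiplier_scale = 1.0

embed_tokens = nn.Embedding(32, 8)      # stand-in for model.embed_tokens
lm_head = nn.Linear(8, 32, bias=False)  # stand-in for lm_head

input_ids = torch.tensor([[1, 2, 3]])
inputs_embeds = embed_tokens(input_ids)
inputs_embeds = inputs_embeds * embedding_multiplier_scale  # scaling added by this commit

hidden_states = inputs_embeds            # stand-in for the decoder stack
logits = lm_head(hidden_states)
logits = logits * output_multiplier_scale  # scaling added by this commit
logits = logits.float()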