Crystalcareai committed
Commit 6dc0ddc
1 Parent(s): f6ef932

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +8 -18
modeling_gemmoe.py CHANGED
@@ -670,16 +670,11 @@ class GemmoeBlockSparseTop2MLP(nn.Module):
         self.act_fn = approx_gelu
 
     def forward(self, hidden_states):
+        hidden_states = hidden_states.to(torch.float32)  # Cast to float32
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.w2(current_hidden_states)
+        current_hidden_states = self.w2(current_hidden_states.to(hidden_states.dtype))  # Cast back to original dtype
         return current_hidden_states
 
-class GemmoeBlockSparseTop2MLP(GemmoeBlockSparseTop2MLP):
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "GemmoeBLockSparseTop2MLP is deprecated by GemmoeBlockSparseTop2MLP and will be removed in v4.40."
-        )
-        super().__init__(*args, **kwargs)
 
 class GemmoeSparseMoeBlock(nn.Module):
     def __init__(self, config):
@@ -699,8 +694,9 @@ class GemmoeSparseMoeBlock(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        hidden_states_float = hidden_states.float()  # Cast to float32
+        router_logits = self.gate(hidden_states_float)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
         topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
         topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
 
@@ -715,7 +711,7 @@ class GemmoeSparseMoeBlock(nn.Module):
         for i in range(self.num_experts):
            expert = self.experts[i]
            expert_output = expert(hidden_states[flat_topk_idx == i])
-           y[flat_topk_idx == i] = expert_output.to(y.dtype)  # Cast expert_output to the same dtype as y
+           y[flat_topk_idx == i] = expert_output
 
         y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
 
@@ -983,7 +979,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         self.embed_tokens = value
 
     @add_start_docstrings_to_model_forward(GEMMOE_INPUTS_DOCSTRING)
-    # Ignore copy
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -994,7 +989,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        output_router_logits: Optional[bool] = None,  # Add this line
+        output_router_logits: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
@@ -1023,7 +1018,6 @@ class GemmoeModel(GemmoePreTrainedModel):
         # Fix for precision issue when casting to bfloat16
         hidden_size_sqrt = math.sqrt(self.config.hidden_size)
         if inputs_embeds.dtype == torch.bfloat16:
-
             pass
 
         hidden_states = inputs_embeds * hidden_size_sqrt
@@ -1110,10 +1104,6 @@ class GemmoeModel(GemmoePreTrainedModel):
             attentions=all_self_attns,
         )
 
-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
@@ -1135,7 +1125,7 @@ class GemmoeModel(GemmoePreTrainedModel):
         causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
         causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
         if attention_mask is not None:
-            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            causal_mask = causal_mask.clone()
             if attention_mask.dim() == 2:
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
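For reference, the routing path touched by the first three hunks works like this: the gate projects each token to per-expert logits, a float32 softmax turns those into routing weights, the two largest weights per token are kept and renormalized to sum to one, each token is run through its selected experts, and the expert outputs are summed with those weights. The snippet below is a minimal, self-contained sketch of that pattern, not the code in modeling_gemmoe.py; SimpleExpert, SimpleTop2MoE, and the layer sizes are made-up placeholders for illustration.

# Minimal sketch of top-2 MoE routing and dispatch. SimpleExpert, SimpleTop2MoE
# and all sizes below are illustrative placeholders, not the model's classes.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleExpert(nn.Module):
    """Stand-in for a gated expert MLP (w1/w3 up-projections, w2 down-projection)."""
    def __init__(self, hidden_dim, ffn_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, hidden_dim, bias=False)

    def forward(self, x):
        # GeGLU-style gating: activation of w1(x) scales w3(x), then project back down.
        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))

class SimpleTop2MoE(nn.Module):
    def __init__(self, hidden_dim=16, ffn_dim=32, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList(SimpleExpert(hidden_dim, ffn_dim) for _ in range(num_experts))

    def forward(self, hidden_states):
        batch, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Router: per-token expert logits, softmaxed in float32 for stability,
        # then keep the top-2 experts and renormalize their weights to sum to 1.
        router_logits = self.gate(hidden_states.float())
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)

        # Dispatch: each token is processed by its selected experts, and the
        # expert outputs are combined with the renormalized routing weights.
        flat_topk_idx = topk_idx.view(-1)
        repeated = hidden_states.repeat_interleave(self.top_k, dim=0)
        y = torch.empty_like(repeated)
        for i in range(self.num_experts):
            mask = flat_topk_idx == i
            if mask.any():
                y[mask] = self.experts[i](repeated[mask]).to(y.dtype)
        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1).to(y.dtype)).sum(dim=1)
        return y.view(batch, seq_len, hidden_dim)

# Usage: route a small random batch through the sketch.
moe = SimpleTop2MoE()
print(moe(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])

Running the router softmax in float32 and casting activations inside the expert MLP, as this commit does, is a common guard against bfloat16 round-off changing which experts win the top-2 selection; that motivation is inferred here from the commit's own "Cast to float32" comments.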
 
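The final hunk only removes an inline comment next to causal_mask.clone(), but the clone it sits on matters: causal_mask comes from expand(), which returns a broadcasted view, so it has to be copied before padding positions are written into it in place. Below is a rough, self-contained sketch of that padding-mask pattern; build_causal_mask and its signature are illustrative assumptions, not part of the model's API.

# Rough sketch of folding a 2D padding mask into a 4D additive causal mask,
# following the pattern in _update_causal_mask. build_causal_mask and its
# shapes are illustrative assumptions, not the model's actual helper.
import torch

def build_causal_mask(attention_mask, seq_len, dtype=torch.float32):
    min_dtype = torch.finfo(dtype).min
    batch_size = attention_mask.shape[0]
    device = attention_mask.device

    # Strictly upper-triangular (future) positions get min_dtype; allowed ones stay 0.0.
    base = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    causal_mask = base[None, None, :, :].to(dtype) * min_dtype
    causal_mask = causal_mask.expand(batch_size, 1, -1, -1)

    # expand() returns a broadcasted view, so clone before editing in place.
    causal_mask = causal_mask.clone()
    if attention_mask.dim() == 2:
        mask_length = attention_mask.shape[-1]
        # Positions that are still open (0.0) but belong to padding tokens get masked too.
        padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask

# Usage: batch of 2 sequences of length 4; the second sequence has one padded position.
attn = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 0.0]])
print(build_causal_mask(attn, seq_len=4).shape)  # torch.Size([2, 1, 4, 4])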