Crystalcareai committed on
Commit 9d9c0e7
1 Parent(s): 7723261

Update modeling_gemmoe.py

Files changed (1)
  1. modeling_gemmoe.py +114 -31
modeling_gemmoe.py CHANGED
@@ -194,42 +194,54 @@ class GemmoeRMSNorm(nn.Module):
 
 ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
 
+class GemmoeRMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        return output * (self.weight + 1)
+
+ALL_LAYERNORM_LAYERS.append(GemmoeRMSNorm)
+
 class GemmoeRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-        self.max_seq_len_cached = None
-
+        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=device, dtype=torch.get_default_dtype())
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-        freqs = torch.outer(t, self.inv_freq.to(t.device))
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
+        freq_exponents = (2.0 / self.dim) * (
+            torch.arange(self.dim // 2, dtype=torch.int64, device="cpu").float()
+        )
+        timescale = self.base ** freq_exponents
+        positions = torch.arange(self.max_seq_len_cached, device="cpu", dtype=torch.int64).float()
+        radians_new = positions[..., None] / timescale[None, None, :]
+        radians_new = radians_new.squeeze(0)
+        emb = torch.cat((radians_new, radians_new), dim=-1)
+        cos = emb.cos().to(device=device, non_blocking=True)
+        sin = emb.sin().to(device=device, non_blocking=True)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        if seq_len is None:
+            seq_len = x.size(2)
+        if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
         return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
         )
+
 
 class GemmoeLinearScalingRotaryEmbedding(GemmoeRotaryEmbedding):
     """GemmoeRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
@@ -948,17 +960,78 @@ GEMMOE_ATTENTION_CLASSES = {
     "sdpa": GemmoeSdpaAttention,
 }
 
+class GemmoeBlockSparseTop2MLP(nn.Module):
+    def __init__(self, config: GemmoeConfig):
+        super().__init__()
+        self.ffn_dim = config.intermediate_size
+        self.hidden_dim = config.hidden_size
+
+        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
+        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
+
+        self.act_fn = approx_gelu
+
+    def forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+class GemmoeSparseMoeBlock(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = 2
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([GemmoeBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(routing_weights, self.top_k, dim=-1, sorted=False)
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+
+        y = torch.empty_like(hidden_states)
+
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            expert_output = expert(hidden_states[flat_topk_idx == i])
+            y[flat_topk_idx == i] = expert_output.to(y.dtype)  # Cast expert_output to the same dtype as y
+
+        y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
 
 class GemmoeDecoderLayer(nn.Module):
     def __init__(self, config: GemmoeConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
-
         self.self_attn = GEMMOE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
 
-        self.mlp = GemMoE(config) if (config.n_routed_experts is not None and \
-            layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0) \
-            else GemmoeMLP(config)
+        if config.n_routed_experts is not None and \
+            layer_idx >= config.first_k_dense_replace and \
+            layer_idx % config.moe_layer_freq == 0:
+            self.block_sparse_moe = GemmoeSparseMoeBlock(config)
+        else:
+            self.mlp = GemmoeMLP(config)
+
         self.input_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = GemmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
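The added `GemmoeSparseMoeBlock` routes each token through its top-2 experts: softmax over the gate logits, keep the two largest probabilities, renormalize them to sum to 1, then combine the two expert outputs with those weights. A toy sketch of just the routing step, using made-up logits rather than real model activations:

```python
import torch
import torch.nn.functional as F

# Four tokens, eight experts: fake router logits for illustration only.
router_logits = torch.randn(4, 8)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, k=2, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

# Each row now holds two expert indices and two weights that sum to 1.
print(topk_idx)
print(topk_weight.sum(dim=-1))  # tensor([1., 1., 1., 1.])
```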
 
@@ -969,6 +1042,7 @@ class GemmoeDecoderLayer(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -985,13 +1059,15 @@ class GemmoeDecoderLayer(nn.Module):
                 If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                 (see `past_key_values`).
             past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
+                and should not be returned during inference.
         """
         if "padding_mask" in kwargs:
             warnings.warn(
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
         residual = hidden_states
-
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
@@ -1009,7 +1085,12 @@
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
+
+        if hasattr(self, 'block_sparse_moe'):
+            hidden_states, router_logits = self.block_sparse_moe(hidden_states)
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
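A layer now owns either `block_sparse_moe` or `mlp`, and the forward pass dispatches on whichever attribute exists. Which layer indices get the sparse block follows the same condition used in `__init__` above; the sketch below walks that rule with illustrative values for `first_k_dense_replace` and `moe_layer_freq` (assumed here, not taken from this repo's config):

```python
# Hypothetical settings, chosen only to illustrate the selection rule.
n_routed_experts = 8
first_k_dense_replace = 1
moe_layer_freq = 1
num_hidden_layers = 6

for layer_idx in range(num_hidden_layers):
    is_moe = (
        n_routed_experts is not None
        and layer_idx >= first_k_dense_replace
        and layer_idx % moe_layer_freq == 0
    )
    kind = "block_sparse_moe (GemmoeSparseMoeBlock)" if is_moe else "mlp (GemmoeMLP)"
    print(f"layer {layer_idx}: {kind}")
# Under these settings, layer 0 stays dense and layers 1-5 use the sparse MoE block.
```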
@@ -1019,10 +1100,12 @@
 
         if use_cache:
             outputs += (present_key_value,)
+
+        if output_router_logits and hasattr(self, 'block_sparse_moe'):
+            outputs += (router_logits,)
 
         return outputs
 
-
 GEMMOE_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
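With `output_router_logits=True`, a layer that owns `block_sparse_moe` appends its router logits to the output tuple, which a training loop can gather for a router (load-balancing) loss; that loss itself is not part of this commit. Purely as an illustration, a sketch in the style of the Switch-Transformers auxiliary loss, with a made-up input and hypothetical helper name:

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts) collected from one MoE layer.
    routing_weights = F.softmax(router_logits, dim=-1)
    _, selected = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts).float()   # (tokens, top_k, experts)
    tokens_per_expert = expert_mask.mean(dim=(0, 1))          # fraction of assignments per expert
    router_prob_per_expert = routing_weights.mean(dim=0)      # mean router probability per expert
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

# Example with fake logits from a single layer of 8 experts:
aux_loss = load_balancing_loss(torch.randn(16, 8), num_experts=8)
```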
 