Charlie81 committed
Commit 7bf23fe · 1 Parent(s): 170c7d7

Revert "refactor sparse"


This reverts commit 170c7d7f55aeef1ca17e395ad279ca2098e57d53.

Files changed (1): modeling_myolmoe.py (+12 -11)
modeling_myolmoe.py CHANGED

@@ -223,7 +223,6 @@ class MyOLMoERouting(nn.Module):
         self.hidden_size = config.hidden_size
         self.routing_type = getattr(config, "routing_type", "sparse")
         self.router_temperature = getattr(config, "router_temperature", 1.0)
-        self.norm_topk_prob = getattr(config, "norm_topk_prob", False)
 
         # Shared components
         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
@@ -231,13 +230,20 @@
         # For non-deterministic routing
         self.gumbel_noise = getattr(config, "gumbel_noise", 0.1)
 
-    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
-        print("TEST testtest123")
         hidden_states = hidden_states.view(-1, hidden_dim)
-        print("TEST 4564645testtest123")
         router_logits = self.gate(hidden_states)
 
+        # Always use softmax, even for "dense" routing
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
         if self.routing_type == "dense":
             # Dense routing - use all experts equally
             routing_weights = torch.ones_like(router_logits) / self.num_experts
@@ -256,16 +262,11 @@
 
         else: # Default sparse routing
             # Standard sparse top-k routing
-            routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
+            routing_weights = F.softmax(router_logits, dim=-1)
             routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
 
-            if self.norm_topk_prob:
-                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-            routing_weights = routing_weights.to(hidden_states.dtype)
-
         return routing_weights, selected_experts, router_logits
-
+
 class OlmoeRotaryEmbedding(nn.Module):
     def __init__(self, config: OlmoeConfig, device=None):
         super().__init__()
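
For reference, here is a minimal, self-contained sketch of the routing path this revert restores: softmax over the router logits, top-k expert selection, optional renormalization of the kept probabilities, and a cast back to the activation dtype. The tensor sizes and the norm_topk_prob flag below are illustrative assumptions, not values taken from the model config.

# Illustrative sketch only: num_experts, top_k, and norm_topk_prob are assumed values.
import torch
import torch.nn.functional as F

num_experts, top_k, norm_topk_prob = 8, 2, True
router_logits = torch.randn(4, num_experts)   # one row of router logits per token

# Softmax in float32 for numerical stability, mirroring the committed code
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# Optionally renormalize so the kept top-k weights sum to 1 per token
if norm_topk_prob:
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

# Cast back to the activation dtype before the weights scale expert outputs
routing_weights = routing_weights.to(router_logits.dtype)
print(selected_experts.shape)        # torch.Size([4, 2])
print(routing_weights.sum(dim=-1))   # ~1.0 per token when norm_topk_prob is True

With norm_topk_prob left False, the kept weights are just the raw softmax probabilities of the selected experts, so they sum to less than 1 per token.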