Revert "refactor sparse"
This reverts commit 170c7d7f55aeef1ca17e395ad279ca2098e57d53.

modeling_myolmoe.py (+12 -11)
@@ -223,7 +223,6 @@ class MyOLMoERouting(nn.Module):
         self.hidden_size = config.hidden_size
         self.routing_type = getattr(config, "routing_type", "sparse")
         self.router_temperature = getattr(config, "router_temperature", 1.0)
-        self.norm_topk_prob = getattr(config, "norm_topk_prob", False)

         # Shared components
         self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
@@ -231,13 +230,20 @@ class MyOLMoERouting(nn.Module):
         # For non-deterministic routing
         self.gumbel_noise = getattr(config, "gumbel_noise", 0.1)

-    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
-        print("TEST testtest123")
         hidden_states = hidden_states.view(-1, hidden_dim)
-        print("TEST 4564645testtest123")
         router_logits = self.gate(hidden_states)

+        # Always use softmax, even for "dense" routing
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+
+        if self.norm_topk_prob:
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
         if self.routing_type == "dense":
             # Dense routing - use all experts equally
             routing_weights = torch.ones_like(router_logits) / self.num_experts
@@ -256,16 +262,11 @@ class MyOLMoERouting(nn.Module):

         else: # Default sparse routing
             # Standard sparse top-k routing
-            routing_weights = F.softmax(router_logits, dim=-1
+            routing_weights = F.softmax(router_logits, dim=-1)
             routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

-            if self.norm_topk_prob:
-                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-            routing_weights = routing_weights.to(hidden_states.dtype)
-
         return routing_weights, selected_experts, router_logits
-
+
 class OlmoeRotaryEmbedding(nn.Module):
     def __init__(self, config: OlmoeConfig, device=None):
         super().__init__()
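For reference, the routing path restored by this revert is standard top-k gating: softmax over the router logits, keep the top_k probabilities per token, optionally renormalize them (the norm_topk_prob branch), and cast back to the activation dtype. Below is a minimal standalone sketch of that sequence; the sizes (num_experts=8, top_k=2, hidden_dim=32) are illustrative only and are not taken from the actual config.

import torch
import torch.nn.functional as F

num_experts, top_k = 8, 2                     # illustrative sizes, not the real config
hidden_dim, num_tokens = 32, 64
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)

# hidden_states already flattened to (tokens, hidden_dim), as in forward()
hidden_states = torch.randn(num_tokens, hidden_dim)
router_logits = gate(hidden_states)           # (tokens, num_experts)

routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# norm_topk_prob branch: rescale the kept probabilities so they sum to 1 per token
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

print(routing_weights.shape, selected_experts.shape)  # torch.Size([64, 2]) torch.Size([64, 2])

The renormalization only matters when top_k < num_experts: without it, the selected weights sum to less than 1, which changes the overall scale of the expert mixture.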