mazesmazes committed
Commit f35df07 · verified · 1 Parent(s): f61bf72

Delete shared_moe_projector.py

Files changed (1)
  1. shared_moe_projector.py +0 -182
shared_moe_projector.py DELETED
@@ -1,182 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F  # noqa: N812
-
-
-class SwiGLUExpert(nn.Module):
-    """SwiGLU expert MLP (used for both shared and routed experts)."""
-
-    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
-        super().__init__()
-        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
-        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
-        self.act = nn.SiLU()
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
-
-
-class SharedMoEBlock(nn.Module):
-    """MoE block with shared expert + sparse routed experts."""
-
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dim: int,
-        output_dim: int,
-        num_experts: int = 4,
-        top_k: int = 2,
-    ):
-        super().__init__()
-        self.num_experts = num_experts
-        self.top_k = top_k
-        self.output_dim = output_dim
-
-        # Router: zero-initialized so routing starts uniform and is learned from scratch
-        self.router = nn.Linear(input_dim, num_experts, bias=False)
-        nn.init.zeros_(self.router.weight)
-
-        # Shared expert (always active)
-        self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
-
-        # Routed experts (sparse)
-        self.experts = nn.ModuleList(
-            [SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
-        )
-
-        # For auxiliary loss (cached to avoid recomputation)
-        self.last_router_logits = None
-        self.last_router_probs = None
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, seq_len, dim = hidden_states.shape
-
-        # Shared expert output (all tokens)
-        shared_out = self.shared_expert(hidden_states)
-
-        # Routing
-        flat_hidden = hidden_states.view(-1, dim)
-        router_logits = self.router(flat_hidden)
-        router_probs = F.softmax(router_logits.float(), dim=-1)
-
-        # Cache for aux loss
-        self.last_router_logits = router_logits
-        self.last_router_probs = router_probs
-
-        # Top-k selection and renormalization
-        top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
-        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
-        top_k_weights = top_k_weights.to(hidden_states.dtype)
-
-        # Routed expert output via token dispatch
-        routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
-        routed_out = routed_out.view(batch_size, seq_len, -1)
-
-        # Combine: shared baseline + routed experts (grow in via zero-init router, gain=0.01 down_proj)
-        return shared_out + routed_out
-
-    def _dispatch_experts(
-        self,
-        hidden_states: torch.Tensor,
-        top_k_indices: torch.Tensor,
-        top_k_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        """Token dispatch - gather tokens per expert, process, scatter back."""
-        num_tokens = hidden_states.shape[0]
-        output = torch.zeros(
-            num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
-        )
-
-        for expert_idx, expert in enumerate(self.experts):
-            expert_mask = top_k_indices == expert_idx
-            if not expert_mask.any():
-                continue
-
-            token_indices, slot_indices = torch.where(expert_mask)
-            expert_input = hidden_states[token_indices]
-            expert_output = expert(expert_input)
-            weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
-            output.index_add_(0, token_indices, expert_output * weights)
-
-        return output
-
-
-def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
-    """Auxiliary loss to encourage balanced expert usage."""
-    _, selected = torch.topk(router_probs, top_k, dim=-1)
-    expert_mask = F.one_hot(selected, num_experts).float()
-    tokens_per_expert = expert_mask.mean(dim=(0, 1))
-    prob_per_expert = router_probs.mean(dim=0)
-    return (tokens_per_expert * prob_per_expert).sum() * num_experts
-
-
-def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
-    """Z-loss to prevent router logits from growing too large."""
-    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
-
-
-class SharedMoEAudioProjector(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-
-        # Temporal downsampling
-        self.k = getattr(config, "projector_pool_stride", 4)
-
-        # Dimensions
-        encoder_dim = config.encoder_dim
-        in_dim = encoder_dim * self.k
-        out_dim = config.llm_dim
-        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
-
-        # MoE config
-        self.num_experts = getattr(config, "num_experts", 4)
-        self.top_k = getattr(config, "num_experts_per_tok", 2)
-        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
-        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
-
-        # Layers
-        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
-
-        # Init
-        self._init_weights(in_dim)
-
-    def _init_weights(self, in_dim: int):
-        with torch.no_grad():
-            # Shared expert - orthogonal init for stable condition numbers
-            nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
-            nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
-            nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)
-
-            # Routed experts - orthogonal for gate/up, tiny orthogonal for down (grow-in)
-            # gain=0.01 gives ~1% initial contribution while maintaining good conditioning
-            for expert in self.moe.experts:
-                nn.init.orthogonal_(expert.gate_proj.weight)
-                nn.init.orthogonal_(expert.up_proj.weight)
-                nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        batch_size, seq_len, dim = x.size()
-
-        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
-        if x.dtype != target_dtype:
-            x = x.to(target_dtype)
-
-        # Pad for pooling (at most k-1 frames -> 1 extra token, negligible impact)
-        if seq_len % self.k:
-            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
-
-        # Temporal pooling: concatenate k consecutive frames into one token
-        x = x.view(batch_size, -1, dim * self.k)
-
-        return self.moe(x)
-
-    def get_aux_loss(self) -> torch.Tensor:
-        """Get auxiliary losses (call after forward)."""
-        if self.moe.last_router_logits is None:
-            return torch.tensor(0.0, device=self.moe.router.weight.device)
-
-        balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
-        z = z_loss(self.moe.last_router_logits)
-
-        return self.aux_loss_coef * balance + self.z_loss_coef * z
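
For context on what the deleted module provided, here is a minimal usage sketch: pooled audio-encoder frames go in, LLM-dimension tokens come out, and the router losses cached during the forward pass are added to the task loss afterwards. The SimpleNamespace config only mirrors the fields the code reads via getattr (encoder_dim and llm_dim are required); the dimensions, the batch of random frames, and the stand-in task loss are illustrative assumptions, not values from this repository.

from types import SimpleNamespace

import torch

from shared_moe_projector import SharedMoEAudioProjector  # the module removed in this commit

# Hypothetical config; field names match the getattr calls in SharedMoEAudioProjector.__init__
config = SimpleNamespace(
    encoder_dim=1024,           # audio encoder hidden size (assumed)
    llm_dim=4096,               # LLM embedding size (assumed)
    projector_pool_stride=4,    # k frames concatenated per projected token
    projector_hidden_dim=None,  # falls back to encoder_dim * k
    num_experts=4,
    num_experts_per_tok=2,
    router_aux_loss_coef=0.02,
    router_z_loss_coef=0.001,
)

projector = SharedMoEAudioProjector(config)

encoder_frames = torch.randn(2, 100, config.encoder_dim)  # (batch, frames, encoder_dim)
llm_tokens = projector(encoder_frames)                    # (batch, frames / 4, llm_dim)

# Router statistics are cached during forward; fetch the aux losses afterwards.
task_loss = llm_tokens.pow(2).mean()                      # stand-in for the real training loss
loss = task_loss + projector.get_aux_loss()
loss.backward()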