Guo committed on
Commit e815555
1 Parent(s): 90dabef
Files changed (3):
  1. gate.py +100 -0
  2. modeling_jetmoe.py +5 -5
  3. moe.py +277 -0
gate.py ADDED
@@ -0,0 +1,100 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class top_k_gating(nn.Module):
+     def __init__(
+         self,
+         input_size,
+         num_experts,
+         top_k,
+     ):
+         """
+         Initialize the top-k gating mechanism.
+
+         Args:
+             input_size (int): Size of the input.
+             num_experts (int): Number of experts.
+             top_k (int): Number of top experts to select.
+         """
+         super().__init__()
+
+         self.num_experts = num_experts
+         self.input_size = input_size
+         assert top_k <= num_experts
+         self.top_k = top_k
+
+         self.layer = nn.Linear(input_size, num_experts, bias=False)
+
+     def extra_repr(self):
+         """
+         Return extra representation string for the module.
+         """
+         return 'k={}, num_experts={}'.format(
+             self.top_k, self.num_experts)
+
+     def compute_aux_loss(self, probs, logits, gates):
+         """
+         Calculate and return the auxiliary loss for the current batch.
+
+         Args:
+             probs (torch.Tensor): Softmax probabilities over all experts, shape [batch_size, num_experts].
+             logits (torch.Tensor): Raw router logits, shape [batch_size, num_experts].
+             gates (torch.Tensor): Sparse gating values, shape [batch_size, num_experts].
+
+         Returns:
+             torch.Tensor: The calculated auxiliary loss (switch load-balancing loss plus 0.1 * z-loss).
+         """
+         count = logits.size(0)
+         probs = probs.sum(0)
+         freq = (gates > 0).float().sum(0)
+         lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()
+
+         switchloss = self.num_experts * (
+             F.normalize(probs, p=1, dim=0) *
+             F.normalize(freq, p=1, dim=0)
+         ).sum()
+         zloss = lsesq / count
+         loss = switchloss + 0.1 * zloss
+
+         return loss
+
+     def forward(self, x):
+         """
+         Compute the top-k gating for the input.
+
+         See paper: https://arxiv.org/abs/1701.06538.
+
+         Args:
+             x (torch.Tensor): Input tensor with shape [batch_size, input_size].
+
+         Returns:
+             torch.Tensor: Top-k expert indices with shape [batch_size, top_k].
+             torch.Tensor: Top-k gating values with shape [batch_size, top_k].
+         """
+         logits = self.layer(x).float()
+         top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)
+         top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x)
+
+         if self.training:
+             probs = torch.softmax(logits, dim=1)
+             zeros = torch.zeros_like(probs)
+             zeros = zeros.to(top_k_gates.dtype)  # convert zeros to match top_k_gates dtype
+             gates = zeros.scatter(1, top_k_indices, top_k_gates)
+             self.loss = self.compute_aux_loss(probs, logits, gates)
+         else:
+             self.loss = 0
+
+         return top_k_indices, top_k_gates
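
For orientation, the router can be exercised on its own. The sketch below is a minimal, hypothetical smoke test (the sizes 8/4/2 and the batch of 3 are illustrative, not taken from the commit) showing the shapes `forward` returns and that the auxiliary loss is only populated in training mode.

# Hypothetical smoke test for top_k_gating; all sizes are illustrative.
import torch
from gate import top_k_gating  # inside the packaged model this is `from .gate import top_k_gating`

router = top_k_gating(input_size=8, num_experts=4, top_k=2)
router.train()

x = torch.randn(3, 8)                    # [batch_size, input_size]
top_k_indices, top_k_gates = router(x)   # both have shape [batch_size, top_k]

print(top_k_indices.shape, top_k_gates.shape)                 # torch.Size([3, 2]) torch.Size([3, 2])
print(torch.allclose(top_k_gates.sum(dim=1), torch.ones(3)))  # True: softmax over the selected logits
print(router.loss)                       # switch loss + 0.1 * z-loss; set to 0 in eval mode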
modeling_jetmoe.py CHANGED
@@ -9,7 +9,7 @@ from torch import nn
  from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
  from torch.nn import functional as F
 
- import megablocks
+ #import megablocks
  from transformers.modeling_outputs import (
      BaseModelOutputWithPast,
      CausalLMOutputWithPast,
@@ -28,7 +28,7 @@ from transformers.utils import (
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
  from transformers.cache_utils import Cache, DynamicCache
  from .configuration_jetmoe import JetMoEConfig
- from jetmoe_model.utils import moe
+ from . import moe
 
  if is_flash_attn_2_available():
      from flash_attn import flash_attn_func, flash_attn_varlen_func
@@ -701,9 +701,9 @@ class JetMoEBlock(nn.Module):
          self.self_attention = JETMOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
          self.post_attention_layernorm = JetMoERMSNorm(config.hidden_size)
 
-         moe_args = megablocks.layers.arguments.from_megatron(config)
-         moe_args.activation_fn = F.silu
-         moe_args.return_bias = False
+         # moe_args = megablocks.layers.arguments.from_megatron(config)
+         # moe_args.activation_fn = F.silu
+         # moe_args.return_bias = False
          # self.mlp = megablocks.layers.dmoe.dMoE(moe_args)
          self.mlp = moe.MoE(
              input_size=config.hidden_size,
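
The hunk above is truncated after `input_size=config.hidden_size,`, so the remaining MoE constructor arguments are not visible in this diff. For orientation only, a hedged sketch of an equivalent standalone construction follows; every value and every argument other than `input_size` (which stands in for `config.hidden_size`) is an illustrative assumption, not something read from the commit.

# Hypothetical standalone construction mirroring the JetMoEBlock change above.
# All sizes below are illustrative assumptions. activation=F.silu is a guess
# motivated by the commented-out megablocks path, which set moe_args.activation_fn = F.silu.
import torch.nn.functional as F
from moe import MoE  # packaged as `from . import moe` inside modeling_jetmoe.py

mlp = MoE(
    input_size=1024,    # stands in for config.hidden_size
    hidden_size=2048,   # assumed per-expert hidden size
    num_experts=8,      # assumed number of experts
    top_k=2,            # assumed experts activated per token
    activation=F.silu,
    glu=True,
)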
moe.py ADDED
@@ -0,0 +1,277 @@
+ import math
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import scattermoe
+ from .gate import top_k_gating
+
+
+ class MoE(nn.Module):
+     """
+     A sparsely gated mixture-of-experts layer with one-layer feed-forward networks as experts.
+
+     Args:
+         input_size: integer - size of the input
+         hidden_size: integer - size of each expert's hidden layer
+         num_experts: an integer - number of experts
+         top_k: an integer - how many experts to use for each batch element
+         bias: a boolean - whether to include bias in linear layers
+         activation: an activation function to apply to the experts' hidden outputs
+         glu: a boolean - whether the experts use a gated linear unit (GLU)
+     """
+
+     def __init__(
+         self,
+         input_size,
+         hidden_size,
+         num_experts,
+         top_k,
+         bias=True,
+         activation=None,
+         glu=True,
+     ):
+         super(MoE, self).__init__()
+
+         self.num_experts = num_experts
+         self.input_size = input_size
+         self.glu = glu
+         if bias:
+             self.bias = torch.nn.Parameter(torch.empty(input_size))
+             torch.nn.init.zeros_(self.bias)
+         else:
+             self.bias = None
+         self.input_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, input_size, hidden_size * 2 if glu else hidden_size)
+         self.output_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, hidden_size, input_size)
+         self.top_k = min(top_k, self.num_experts)
+         self.activation = activation
+
+         self.router = top_k_gating(
+             input_size=input_size,
+             num_experts=num_experts,
+             top_k=top_k,
+         )
+
+     def extra_repr(self):
+         return 'k={}, e={}'.format(
+             self.top_k, self.num_experts)
+
+     def get_aux_loss_and_clear(self):
+         """
+         Get the accumulated auxiliary loss and clear it.
+
+         Returns:
+             float: Accumulated auxiliary loss.
+         """
+         return self.gate.get_aux_loss_and_clear()
+
+     def compute_gate(self, x):
+         top_k_indices, self.top_k_gates = self.router(x)
+
+         with torch.no_grad():
+             self.sorted_expert_idxs, self.sorted_scattered_idxs = scattermoe.kernels.ops.flatten_and_sort(top_k_indices)
+             self.padded_block_idxs, self.expert_offsets = scattermoe.kernels.ops.padded_block_indices(self.sorted_expert_idxs, self.num_experts)
+
+         return self.router.loss
+
+     def batch_forward(self, x):
+         """
+         Forward pass of the mixture-of-experts layer for batched input.
+
+         Args:
+             x (Tensor): Input tensor of shape [batch_size, length, input_size].
+
+         Returns:
+             Tensor: Output tensor of shape [batch_size, length, input_size].
+             Tensor: Auxiliary (load-balancing) loss.
+         """
+         bsz, length, emb_size = x.size()
+         x = x.reshape(-1, emb_size)
+
+         loss = self.compute_gate(x)
+
+         h = self.input_linear(
+             x, self.top_k,
+             self.sorted_expert_idxs, self.sorted_scattered_idxs,
+             self.padded_block_idxs, self.expert_offsets,
+             grouped_out=True
+         )
+
+         if self.glu:
+             h, g = h.chunk(2, dim=-1)
+             h = self.activation(h) * g
+         else:
+             h = self.activation(h)
+
+         y = self.output_linear(
+             h, 1,
+             self.sorted_expert_idxs, self.sorted_scattered_idxs,
+             self.padded_block_idxs, self.expert_offsets,
+             grouped_in=True,
+             gates=self.top_k_gates,
+         )
+
+         y = y.view(bsz, length, self.input_size)
+         if self.bias is not None:
+             y = y + self.bias
+         return y, loss
+
+     def single_forward(self, x):
+         bsz, length, emb_size = x.size()
+
+         x = x.reshape(1, self.input_size)
+         top_k_indices, top_k_gates = self.router(x)
+         loss = self.router.loss
+
+         y_list = []
+         for i in range(self.top_k):
+             expert_idx = top_k_indices[0, i]
+
+             h = F.linear(x, self.input_linear.weight[expert_idx])
+             if self.glu:
+                 h, g = h.chunk(2, dim=-1)
+                 h = self.activation(h) * g
+             else:
+                 h = self.activation(h)
+             y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0, i]
+
+             y_list.append(y)
+
+         y = sum(y_list)
+         y = y.view(bsz, length, self.input_size)
+         if self.bias is not None:
+             y = y + self.bias
+         return y, loss
+
+     def forward(self, x):
+         """
+         Forward pass of the mixture-of-experts layer.
+
+         Args:
+             x (Tensor): Input tensor of shape [batch_size, length, input_size].
+
+         Returns:
+             Tensor: Output tensor.
+             Tensor: Auxiliary loss.
+         """
+         bsz, length, emb_size = x.size()
+         if bsz * length == 1:
+             return self.single_forward(x)
+         else:
+             return self.batch_forward(x)
+
+     def batch_map(self, x):
+         """
+         Map input through the experts' input projections.
+
+         Args:
+             x (Tensor): Input tensor of shape [batch_size, length, input_size].
+
+         Returns:
+             Tensor: Mapped tensor of shape [batch_size, length, top_k, -1].
+             Tensor: Auxiliary loss.
+         """
+         bsz, length, emb_size = x.size()
+         x = x.reshape(-1, emb_size)
+         loss = self.compute_gate(x)
+
+         y = self.input_linear(
+             x, self.top_k,
+             self.sorted_expert_idxs, self.sorted_scattered_idxs,
+             self.padded_block_idxs, self.expert_offsets,
+         )
+         y = y.view(bsz, length, self.top_k, -1)
+         return y, loss
+
+     def single_map(self, x):
+         bsz, length, emb_size = x.size()
+
+         x = x.reshape(1, self.input_size)
+         self.top_k_indices, self.top_k_gates = self.router(x)
+         loss = self.router.loss
+
+         y_list = []
+         for i in range(self.top_k):
+             expert_idx = self.top_k_indices[0, i]
+             y = F.linear(x, self.input_linear.weight[expert_idx])
+             y_list.append(y)
+         y = torch.cat(y_list, dim=0)
+         y = y.view(bsz, length, self.top_k, -1)
+         return y, loss
+
+     def map(self, x):
+         """
+         Map input through the experts' input projections.
+
+         Args:
+             x (Tensor): Input tensor of shape [batch_size, length, input_size].
+
+         Returns:
+             Tensor: Mapped tensor of shape [batch_size, length, top_k, -1].
+             Tensor: Auxiliary loss.
+         """
+         bsz, length, emb_size = x.size()
+         if bsz * length == 1:
+             return self.single_map(x)
+         else:
+             return self.batch_map(x)
+
+     def batch_reduce(self, x):
+         """
+         Reduce the mapped output with the experts' output projections and gating weights.
+
+         Args:
+             x (Tensor): Mapped tensor of shape [batch_size, length, top_k, hidden_size].
+
+         Returns:
+             Tensor: Reduced output tensor of shape [batch_size, length, input_size].
+         """
+         bsz, length, k, emb_size = x.size()
+         assert k == self.top_k
+         x = x.reshape(-1, emb_size)
+
+         y = self.output_linear(
+             x, 1,
+             self.sorted_expert_idxs, self.sorted_scattered_idxs,
+             self.padded_block_idxs, self.expert_offsets,
+             gates=self.top_k_gates,
+         )
+         y = y.view(bsz, length, self.input_size)
+         return y
+
+     def single_reduce(self, x):
+         bsz, length, k, emb_size = x.size()
+
+         x = x.reshape(k, emb_size)
+
+         y_list = []
+         for i in range(self.top_k):
+             expert_idx = self.top_k_indices[0, i]
+             y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0, i]
+             y_list.append(y)
+         y = sum(y_list)
+         y = y.view(bsz, length, self.input_size)
+         return y
+
+     def reduce(self, x):
+         """
+         Reduce the mapped output.
+
+         Args:
+             x (Tensor): Mapped tensor of shape [batch_size, length, top_k, hidden_size].
+
+         Returns:
+             Tensor: Reduced output tensor of shape [batch_size, length, input_size].
+         """
+         bsz, length, k, emb_size = x.size()
+         if bsz * length == 1:
+             return self.single_reduce(x)
+         else:
+             return self.batch_reduce(x)
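
To close the loop, here is a hedged usage sketch of the layer above. The sizes are illustrative; the batched path runs through scattermoe's Triton kernels and therefore assumes a CUDA device, while the single-token path (`batch_size * length == 1`) falls back to plain `F.linear` calls. `glu=False` is chosen only to keep the `map`/`reduce` shapes simple, since with `glu=True` the mapped tensor carries `2 * hidden_size` features that the caller must chunk and gate itself.

# Hypothetical usage of the MoE layer above; all sizes are illustrative.
import torch
import torch.nn.functional as F
from moe import MoE  # imported as `from . import moe` inside modeling_jetmoe.py

moe_layer = MoE(input_size=64, hidden_size=128, num_experts=4, top_k=2,
                activation=F.silu, glu=False).cuda()
moe_layer.train()

x = torch.randn(2, 10, 64, device="cuda")   # [batch_size, length, input_size]
y, aux_loss = moe_layer(x)                  # y: [2, 10, 64]; aux_loss is the router's balancing loss

# map/reduce split the same computation in two: `map` applies only the per-expert
# input projections, `reduce` applies the output projections weighted by the stored
# gates, so a caller can insert its own computation between the two calls.
h, aux_loss = moe_layer.map(x)              # [2, 10, top_k, hidden_size]
out = moe_layer.reduce(F.silu(h))           # [2, 10, 64]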