Guo committed on
Commit dd9f628
1 Parent(s): 52cca48
Files changed (3)
  1. gate.py +0 -100
  2. modeling_jetmoe.py +364 -1
  3. moe.py +0 -277
gate.py DELETED
@@ -1,100 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class top_k_gating(nn.Module):
-    def __init__(
-        self,
-        input_size,
-        num_experts,
-        top_k,
-    ):
-        """
-        Initialize the top-k gating mechanism.
-
-        Args:
-            input_size (int): Size of the input.
-            num_experts (int): Number of experts.
-            top_k (int): Number of top experts to select.
-            acc_aux_loss (bool): Whether to accumulate auxiliary loss statistics.
-            dropout (float): Dropout rate for gating network.
-            hidden_size (int): Hidden size of the gating network.
-            sample_topk (int): Number of top-k experts to sample during training.
-            aux_loss (str): Type of auxiliary loss ('mi' or 'switch').
-            gate_type (str): Type of gating mechanism ('mlp', 'linear', or 'gmm').
-        """
-        super().__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        assert top_k <= num_experts
-        self.top_k = top_k
-
-        self.layer = nn.Linear(input_size, num_experts, bias=False)
-
-    def extra_repr(self):
-        """
-        Return extra representation string for the module.
-        """
-        return 'k={}, num_experts={}'.format(
-            self.top_k, self.num_experts)
-
-    def compute_aux_loss(self, probs, logits, gates):
-        """
-        Calculate and return the auxiliary loss based on the accumulated statistics.
-
-        Args:
-            eps (float): Small epsilon value for numerical stability.
-
-        Returns:
-            torch.Tensor: The calculated auxiliary loss.
-        """
-        count = logits.size(0)
-        probs = probs.sum(0)
-        freq = (gates > 0).float().sum(0)
-        lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()
-
-        switchloss = self.num_experts * (
-            F.normalize(probs, p=1, dim=0) *
-            F.normalize(freq, p=1, dim=0)
-        ).sum()
-        zloss = lsesq / count
-        loss = switchloss + 0.1 * zloss
-
-        return loss
-
-    def forward(self, x):
-        """
-        Compute the top-k gating for the input.
-
-        See paper: https://arxiv.org/abs/1701.06538.
-
-        Args:
-            x (torch.Tensor): Input tensor with shape [batch_size, input_size].
-            skip_mask (torch.Tensor): Skip mask tensor (binary) with the same shape as `x`.
-            x: input Tensor with shape [batch_size, input_size]
-            train: a boolean - we only add noise at training time.
-            noise_epsilon: a float
-
-        Returns:
-            torch.Tensor: Top-k indices.
-            torch.Tensor: Top-k gating values.
-            torch.Tensor: Probability values for each expert.
-            gates: a Tensor with shape [batch_size, num_experts]
-            load: a Tensor with shape [num_experts]
-        """
-
-        logits = self.layer(x).float()
-        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)
-        top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x)
-
-        if self.training:
-            probs = torch.softmax(logits, dim=1)
-            zeros = torch.zeros_like(probs)
-            zeros = zeros.to(top_k_gates.dtype) # Convert zeros to match top_k_gates dtype
-            gates = zeros.scatter(1, top_k_indices, top_k_gates)
-            self.loss = self.compute_aux_loss(probs, logits, gates)
-        else:
-            self.loss = 0
-
-        return top_k_indices, top_k_gates
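For reference, the auxiliary loss returned by `compute_aux_loss` above combines a Switch-Transformer-style load-balancing term with a z-loss on the router logits. Written out (our notation, not part of the commit), for a batch of $B$ token vectors with router logits $z_{b,i}$ over $N_e$ experts:

$$\mathcal{L}_{\text{aux}} = N_e \sum_{i=1}^{N_e} \hat{p}_i \, \hat{f}_i \;+\; 0.1 \cdot \frac{1}{B} \sum_{b=1}^{B} \Big( \log \sum_{i=1}^{N_e} e^{z_{b,i}} \Big)^2$$

where $\hat{p}_i$ is the batch-summed softmax probability of expert $i$ and $\hat{f}_i$ is the count of tokens whose top-k gate for expert $i$ is nonzero, each $L_1$-normalized over experts.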
modeling_jetmoe.py CHANGED
@@ -27,7 +27,7 @@ from transformers.utils import (
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.cache_utils import Cache, DynamicCache
 from .configuration_jetmoe import JetMoEConfig
-from . import moe
+import scattermoe

 try:
     if is_flash_attn_2_available():
@@ -43,6 +43,369 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "jetmoe"
 _CONFIG_FOR_DOC = "JetMoEConfig"

+class top_k_gating(nn.Module):
+    def __init__(
+        self,
+        input_size,
+        num_experts,
+        top_k,
+    ):
+        """
+        Initialize the top-k gating mechanism.
+
+        Args:
+            input_size (int): Size of the input.
+            num_experts (int): Number of experts.
+            top_k (int): Number of top experts to select.
+            acc_aux_loss (bool): Whether to accumulate auxiliary loss statistics.
+            dropout (float): Dropout rate for gating network.
+            hidden_size (int): Hidden size of the gating network.
+            sample_topk (int): Number of top-k experts to sample during training.
+            aux_loss (str): Type of auxiliary loss ('mi' or 'switch').
+            gate_type (str): Type of gating mechanism ('mlp', 'linear', or 'gmm').
+        """
+        super().__init__()
+
+        self.num_experts = num_experts
+        self.input_size = input_size
+        assert top_k <= num_experts
+        self.top_k = top_k
+
+        self.layer = nn.Linear(input_size, num_experts, bias=False)
+
+    def extra_repr(self):
+        """
+        Return extra representation string for the module.
+        """
+        return 'k={}, num_experts={}'.format(
+            self.top_k, self.num_experts)
+
+    def compute_aux_loss(self, probs, logits, gates):
+        """
+        Calculate and return the auxiliary loss based on the accumulated statistics.
+
+        Args:
+            eps (float): Small epsilon value for numerical stability.
+
+        Returns:
+            torch.Tensor: The calculated auxiliary loss.
+        """
+        count = logits.size(0)
+        probs = probs.sum(0)
+        freq = (gates > 0).float().sum(0)
+        lsesq = (torch.log(torch.exp(logits).sum(dim=-1)) ** 2).sum()
+
+        switchloss = self.num_experts * (
+            F.normalize(probs, p=1, dim=0) *
+            F.normalize(freq, p=1, dim=0)
+        ).sum()
+        zloss = lsesq / count
+        loss = switchloss + 0.1 * zloss
+
+        return loss
+
+    def forward(self, x):
+        """
+        Compute the top-k gating for the input.
+
+        See paper: https://arxiv.org/abs/1701.06538.
+
+        Args:
+            x (torch.Tensor): Input tensor with shape [batch_size, input_size].
+            skip_mask (torch.Tensor): Skip mask tensor (binary) with the same shape as `x`.
+            x: input Tensor with shape [batch_size, input_size]
+            train: a boolean - we only add noise at training time.
+            noise_epsilon: a float
+
+        Returns:
+            torch.Tensor: Top-k indices.
+            torch.Tensor: Top-k gating values.
+            torch.Tensor: Probability values for each expert.
+            gates: a Tensor with shape [batch_size, num_experts]
+            load: a Tensor with shape [num_experts]
+        """
+
+        logits = self.layer(x).float()
+        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)
+        top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(x)
+
+        if self.training:
+            probs = torch.softmax(logits, dim=1)
+            zeros = torch.zeros_like(probs)
+            zeros = zeros.to(top_k_gates.dtype) # Convert zeros to match top_k_gates dtype
+            gates = zeros.scatter(1, top_k_indices, top_k_gates)
+            self.loss = self.compute_aux_loss(probs, logits, gates)
+        else:
+            self.loss = 0
+
+        return top_k_indices, top_k_gates
+
+class MoE(nn.Module):
+    """
+    A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
+
+
+    Args:
+        input_size: integer - size of the input
+        head_size: integer - size of the expert's hidden layer
+        num_experts: an integer - number of experts
+        top_k: an integer - how many experts to use for each batch element
+        bias: a boolean - whether to include bias in linear layers
+        activation: an activation function to apply to expert's outputs
+        acc_aux_loss: a boolean - whether to accumulate auxiliary loss
+        hidden_size: an integer - hidden size of the experts
+        gating_dropout: a float - dropout rate for gating network
+        sample_topk: an integer - how many experts to sample during training
+        gating_size: an integer - size of the gating network
+        aux_loss: a string - type of auxiliary loss ('mi' or 'sparse')
+        gate_type: a string - type of gating mechanism ('mlp' or 'topk')
+    """
+
+    def __init__(
+        self,
+        input_size,
+        hidden_size,
+        num_experts,
+        top_k,
+        bias=True,
+        activation=None,
+        glu=True,
+    ):
+        super(MoE, self).__init__()
+
+        self.num_experts = num_experts
+        self.input_size = input_size
+        self.glu = glu
+        if bias:
+            self.bias = torch.nn.Parameter(torch.empty(input_size))
+            torch.nn.init.zeros_(self.bias)
+        else:
+            self.bias = None
+        self.input_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, input_size, hidden_size * 2 if glu else hidden_size)
+        self.output_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, hidden_size, input_size)
+        self.top_k = min(top_k, self.num_experts)
+        self.activation = activation
+
+        self.router = top_k_gating(
+            input_size=input_size,
+            num_experts=num_experts,
+            top_k=top_k,
+        )
+
+    def extra_repr(self):
+        return 'k={}, e={}'.format(
+            self.top_k, self.num_experts)
+
+    def get_aux_loss_and_clear(self):
+        """
+        Get the accumulated auxiliary loss and clear it.
+
+        Returns:
+            float: Accumulated auxiliary loss.
+        """
+
+        return self.gate.get_aux_loss_and_clear()
+
+    def compute_gate(self, x):
+        top_k_indices, self.top_k_gates = self.router(x)
+
+        with torch.no_grad():
+            self.sorted_expert_idxs, self.sorted_scattered_idxs = scattermoe.kernels.ops.flatten_and_sort(top_k_indices)
+            self.padded_block_idxs, self.expert_offsets = scattermoe.kernels.ops.padded_block_indices(self.sorted_expert_idxs, self.num_experts)
+
+        return self.router.loss
+
+    def batch_forward(self, x):
+        """
+        Forward pass of the mixture of experts layer.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        bsz, length, emb_size = x.size()
+        x = x.reshape(-1, emb_size)
+
+        loss = self.compute_gate(x)
+
+        h = self.input_linear(
+            x, self.top_k,
+            self.sorted_expert_idxs, self.sorted_scattered_idxs,
+            self.padded_block_idxs, self.expert_offsets,
+            grouped_out=True
+        )
+
+        if self.glu:
+            h, g = h.chunk(2, dim=-1)
+            h = self.activation(h) * g
+        else:
+            h = self.activation(h)
+
+        y = self.output_linear(
+            h, 1,
+            self.sorted_expert_idxs, self.sorted_scattered_idxs,
+            self.padded_block_idxs, self.expert_offsets,
+            grouped_in=True,
+            gates=self.top_k_gates,
+        )
+
+        y = y.view(bsz, length, self.input_size)
+        if self.bias is not None:
+            y = y + self.bias
+        return y, loss
+
+    def single_forward(self, x):
+        bsz, length, emb_size = x.size()
+
+        x = x.reshape(1, self.input_size)
+        top_k_indices, top_k_gates = self.router(x)
+        loss = self.router.loss
+
+        y_list = []
+        for i in range(self.top_k):
+            expert_idx = top_k_indices[0,i]
+
+            h = F.linear(x, self.input_linear.weight[expert_idx])
+            if self.glu:
+                h, g = h.chunk(2, dim=-1)
+                h = self.activation(h) * g
+            else:
+                h = self.activation(h)
+            y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0,i]
+
+            y_list.append(y)
+
+        y = sum(y_list)
+        y = y.view(bsz, length, self.input_size)
+        if self.bias is not None:
+            y = y + self.bias
+        return y, loss
+
+    def forward(self, x):
+        """
+        Forward pass of the mixture of experts layer.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        bsz, length, emb_size = x.size()
+        if bsz * length ==1:
+            return self.single_forward(x)
+        else:
+            return self.batch_forward(x)
+
+    def batch_map(self, x):
+        """
+        Map input through the mixture of experts layer.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        bsz, length, emb_size = x.size()
+        x = x.reshape(-1, emb_size)
+        loss = self.compute_gate(x)
+
+        y = self.input_linear(
+            x, self.top_k,
+            self.sorted_expert_idxs, self.sorted_scattered_idxs,
+            self.padded_block_idxs, self.expert_offsets,
+        )
+        y = y.view(bsz, length, self.top_k, -1)
+        return y, loss
+
+    def single_map(self, x):
+        bsz, length, emb_size = x.size()
+
+        x = x.reshape(1, self.input_size)
+        self.top_k_indices, self.top_k_gates = self.router(x)
+        loss = self.router.loss
+
+        y_list = []
+        for i in range(self.top_k):
+            expert_idx = self.top_k_indices[0,i]
+            y = F.linear(x, self.input_linear.weight[expert_idx])
+            y_list.append(y)
+        y = torch.cat(y_list, dim=0)
+        y = y.view(bsz, length, self.top_k, -1)
+        return y, loss
+
+    def map(self, x):
+        """
+        Map input through the mixture of experts layer.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Returns:
+            Tensor: Output tensor.
+        """
+        bsz, length, emb_size = x.size()
+        if bsz * length ==1:
+            return self.single_map(x)
+        else:
+            return self.batch_map(x)
+
+    def batch_reduce(self, x):
+        """
+        Reduce the mapped output.
+
+        Args:
+            x (Tensor): Mapped output tensor.
+
+        Returns:
+            Tensor: Reduced output tensor.
+        """
+
+        bsz, length, k, emb_size = x.size()
+        assert k == self.top_k
+        x = x.reshape(-1, emb_size)
+
+        y = self.output_linear(
+            x, 1,
+            self.sorted_expert_idxs, self.sorted_scattered_idxs,
+            self.padded_block_idxs, self.expert_offsets,
+            gates=self.top_k_gates,
+        )
+        y = y.view(bsz, length, self.input_size)
+        return y
+
+    def single_reduce(self, x):
+        bsz, length, k, emb_size = x.size()
+
+        x = x.reshape(k, emb_size)
+
+        y_list = []
+        for i in range(self.top_k):
+            expert_idx = self.top_k_indices[0,i]
+            y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0,i]
+            y_list.append(y)
+        y = sum(y_list)
+        y = y.view(bsz, length, self.input_size)
+        return y
+
+    def reduce(self, x):
+        """
+        Reduce the mapped output.
+
+        Args:
+            x (Tensor): Mapped output tensor.
+
+        Returns:
+            Tensor: Reduced output tensor.
+        """
+        bsz, length, k, emb_size = x.size()
+        if bsz * length ==1:
+            return self.single_reduce(x)
+        else:
+            return self.batch_reduce(x)

 @dataclass
 class JetMoEBaseModelOutputWithPast(BaseModelOutputWithPast):
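For orientation, here is a minimal sketch of how the inlined `MoE` layer could be exercised on its own. The dimensions are illustrative, `F.silu` is one plausible choice for `activation` (the default of `None` would fail inside `batch_forward`), and the batched path assumes the `scattermoe` Triton kernels are available on a CUDA device; none of this is part of the commit.

import torch
import torch.nn.functional as F

# Illustrative sizes only (hypothetical, not from the commit).
moe_layer = MoE(
    input_size=1024,    # model hidden size
    hidden_size=2048,   # per-expert FFN hidden size
    num_experts=8,
    top_k=2,
    activation=F.silu,  # with glu=True (default) the layer computes silu(h) * g
).cuda()

x = torch.randn(4, 16, 1024, device="cuda")  # [batch, seq_len, input_size]
moe_layer.train()
y, aux_loss = moe_layer(x)  # routed expert output and the router's auxiliary loss
assert y.shape == x.shape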
moe.py DELETED
@@ -1,277 +0,0 @@
-import math
-from typing import List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import scattermoe
-from .gate import top_k_gating
-
-
-class MoE(nn.Module):
-    """
-    A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
-
-
-    Args:
-        input_size: integer - size of the input
-        head_size: integer - size of the expert's hidden layer
-        num_experts: an integer - number of experts
-        top_k: an integer - how many experts to use for each batch element
-        bias: a boolean - whether to include bias in linear layers
-        activation: an activation function to apply to expert's outputs
-        acc_aux_loss: a boolean - whether to accumulate auxiliary loss
-        hidden_size: an integer - hidden size of the experts
-        gating_dropout: a float - dropout rate for gating network
-        sample_topk: an integer - how many experts to sample during training
-        gating_size: an integer - size of the gating network
-        aux_loss: a string - type of auxiliary loss ('mi' or 'sparse')
-        gate_type: a string - type of gating mechanism ('mlp' or 'topk')
-    """
-
-    def __init__(
-        self,
-        input_size,
-        hidden_size,
-        num_experts,
-        top_k,
-        bias=True,
-        activation=None,
-        glu=True,
-    ):
-        super(MoE, self).__init__()
-
-        self.num_experts = num_experts
-        self.input_size = input_size
-        self.glu = glu
-        if bias:
-            self.bias = torch.nn.Parameter(torch.empty(input_size))
-            torch.nn.init.zeros_(self.bias)
-        else:
-            self.bias = None
-        self.input_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, input_size, hidden_size * 2 if glu else hidden_size)
-        self.output_linear = scattermoe.parallel_experts.ParallelExperts(num_experts, hidden_size, input_size)
-        self.top_k = min(top_k, self.num_experts)
-        self.activation = activation
-
-        self.router = top_k_gating(
-            input_size=input_size,
-            num_experts=num_experts,
-            top_k=top_k,
-        )
-
-    def extra_repr(self):
-        return 'k={}, e={}'.format(
-            self.top_k, self.num_experts)
-
-    def get_aux_loss_and_clear(self):
-        """
-        Get the accumulated auxiliary loss and clear it.
-
-        Returns:
-            float: Accumulated auxiliary loss.
-        """
-
-        return self.gate.get_aux_loss_and_clear()
-
-    def compute_gate(self, x):
-        top_k_indices, self.top_k_gates = self.router(x)
-
-        with torch.no_grad():
-            self.sorted_expert_idxs, self.sorted_scattered_idxs = scattermoe.kernels.ops.flatten_and_sort(top_k_indices)
-            self.padded_block_idxs, self.expert_offsets = scattermoe.kernels.ops.padded_block_indices(self.sorted_expert_idxs, self.num_experts)
-
-        return self.router.loss
-
-    def batch_forward(self, x):
-        """
-        Forward pass of the mixture of experts layer.
-
-        Args:
-            x (Tensor): Input tensor.
-
-        Returns:
-            Tensor: Output tensor.
-        """
-        bsz, length, emb_size = x.size()
-        x = x.reshape(-1, emb_size)
-
-        loss = self.compute_gate(x)
-
-        h = self.input_linear(
-            x, self.top_k,
-            self.sorted_expert_idxs, self.sorted_scattered_idxs,
-            self.padded_block_idxs, self.expert_offsets,
-            grouped_out=True
-        )
-
-        if self.glu:
-            h, g = h.chunk(2, dim=-1)
-            h = self.activation(h) * g
-        else:
-            h = self.activation(h)
-
-        y = self.output_linear(
-            h, 1,
-            self.sorted_expert_idxs, self.sorted_scattered_idxs,
-            self.padded_block_idxs, self.expert_offsets,
-            grouped_in=True,
-            gates=self.top_k_gates,
-        )
-
-        y = y.view(bsz, length, self.input_size)
-        if self.bias is not None:
-            y = y + self.bias
-        return y, loss
-
-    def single_forward(self, x):
-        bsz, length, emb_size = x.size()
-
-        x = x.reshape(1, self.input_size)
-        top_k_indices, top_k_gates = self.router(x)
-        loss = self.router.loss
-
-        y_list = []
-        for i in range(self.top_k):
-            expert_idx = top_k_indices[0,i]
-
-            h = F.linear(x, self.input_linear.weight[expert_idx])
-            if self.glu:
-                h, g = h.chunk(2, dim=-1)
-                h = self.activation(h) * g
-            else:
-                h = self.activation(h)
-            y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0,i]
-
-            y_list.append(y)
-
-        y = sum(y_list)
-        y = y.view(bsz, length, self.input_size)
-        if self.bias is not None:
-            y = y + self.bias
-        return y, loss
-
-    def forward(self, x):
-        """
-        Forward pass of the mixture of experts layer.
-
-        Args:
-            x (Tensor): Input tensor.
-
-        Returns:
-            Tensor: Output tensor.
-        """
-        bsz, length, emb_size = x.size()
-        if bsz * length ==1:
-            return self.single_forward(x)
-        else:
-            return self.batch_forward(x)
-
-    def batch_map(self, x):
-        """
-        Map input through the mixture of experts layer.
-
-        Args:
-            x (Tensor): Input tensor.
-
-        Returns:
-            Tensor: Output tensor.
-        """
-        bsz, length, emb_size = x.size()
-        x = x.reshape(-1, emb_size)
-        loss = self.compute_gate(x)
-
-        y = self.input_linear(
-            x, self.top_k,
-            self.sorted_expert_idxs, self.sorted_scattered_idxs,
-            self.padded_block_idxs, self.expert_offsets,
-        )
-        y = y.view(bsz, length, self.top_k, -1)
-        return y, loss
-
-    def single_map(self, x):
-        bsz, length, emb_size = x.size()
-
-        x = x.reshape(1, self.input_size)
-        self.top_k_indices, self.top_k_gates = self.router(x)
-        loss = self.router.loss
-
-        y_list = []
-        for i in range(self.top_k):
-            expert_idx = self.top_k_indices[0,i]
-            y = F.linear(x, self.input_linear.weight[expert_idx])
-            y_list.append(y)
-        y = torch.cat(y_list, dim=0)
-        y = y.view(bsz, length, self.top_k, -1)
-        return y, loss
-
-    def map(self, x):
-        """
-        Map input through the mixture of experts layer.
-
-        Args:
-            x (Tensor): Input tensor.
-
-        Returns:
-            Tensor: Output tensor.
-        """
-        bsz, length, emb_size = x.size()
-        if bsz * length ==1:
-            return self.single_map(x)
-        else:
-            return self.batch_map(x)
-
-    def batch_reduce(self, x):
-        """
-        Reduce the mapped output.
-
-        Args:
-            x (Tensor): Mapped output tensor.
-
-        Returns:
-            Tensor: Reduced output tensor.
-        """
-
-        bsz, length, k, emb_size = x.size()
-        assert k == self.top_k
-        x = x.reshape(-1, emb_size)
-
-        y = self.output_linear(
-            x, 1,
-            self.sorted_expert_idxs, self.sorted_scattered_idxs,
-            self.padded_block_idxs, self.expert_offsets,
-            gates=self.top_k_gates,
-        )
-        y = y.view(bsz, length, self.input_size)
-        return y
-
-    def single_reduce(self, x):
-        bsz, length, k, emb_size = x.size()
-
-        x = x.reshape(k, emb_size)
-
-        y_list = []
-        for i in range(self.top_k):
-            expert_idx = self.top_k_indices[0,i]
-            y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0,i]
-            y_list.append(y)
-        y = sum(y_list)
-        y = y.view(bsz, length, self.input_size)
-        return y
-
-    def reduce(self, x):
-        """
-        Reduce the mapped output.
-
-        Args:
-            x (Tensor): Mapped output tensor.
-
-        Returns:
-            Tensor: Reduced output tensor.
-        """
-        bsz, length, k, emb_size = x.size()
-        if bsz * length ==1:
-            return self.single_reduce(x)
-        else:
-            return self.batch_reduce(x)
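The `map`/`reduce` pair above splits the expert computation in two: `map` runs `input_linear` and caches the routing state, while `reduce` runs `output_linear` weighted by the cached gates, leaving whatever happens in between to the caller. A hypothetical round trip is sketched below; it is our reading of the code rather than an example shipped by the commit, it uses `glu=False` so the mapped features keep the expert hidden size, and it assumes the same `scattermoe` CUDA kernels as above.

import torch
import torch.nn.functional as F

# Hypothetical sketch (not from the commit): glu=False keeps the mapped feature
# size equal to hidden_size, so the two expert projections compose directly.
moe_layer = MoE(input_size=1024, hidden_size=2048, num_experts=8, top_k=2,
                activation=F.silu, glu=False).cuda()
x = torch.randn(4, 16, 1024, device="cuda")

h, aux_loss = moe_layer.map(x)  # per-expert features: [batch, seq_len, top_k, hidden_size]
h = F.silu(h)                   # the caller owns whatever happens between map and reduce
y = moe_layer.reduce(h)         # gated sum over the top_k experts: [batch, seq_len, input_size]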