|
import torch
import torch.nn as nn
import torch.nn.functional as F

import scattermoe
from .gate import top_k_gating


class MoE(nn.Module):
""" |
|
A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts. |
|
|
|
|
|
Args: |
|
input_size: integer - size of the input |
|
head_size: integer - size of the expert's hidden layer |
|
num_experts: an integer - number of experts |
|
top_k: an integer - how many experts to use for each batch element |
|
bias: a boolean - whether to include bias in linear layers |
|
activation: an activation function to apply to expert's outputs |
|
acc_aux_loss: a boolean - whether to accumulate auxiliary loss |
|
hidden_size: an integer - hidden size of the experts |
|
gating_dropout: a float - dropout rate for gating network |
|
sample_topk: an integer - how many experts to sample during training |
|
gating_size: an integer - size of the gating network |
|
aux_loss: a string - type of auxiliary loss ('mi' or 'sparse') |
|
gate_type: a string - type of gating mechanism ('mlp' or 'topk') |
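
    Example (an illustrative sketch, not part of the original source; it
    assumes a CUDA device, since the scattermoe kernels run on GPU, and
    picks ``F.silu`` as the activation):

        >>> moe = MoE(input_size=512, hidden_size=1024, num_experts=8,
        ...           top_k=2, activation=F.silu).cuda()
        >>> x = torch.randn(4, 16, 512, device='cuda')
        >>> y, aux_loss = moe(x)  # output plus the router's auxiliary loss
        >>> y.shape
        torch.Size([4, 16, 512])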
|
""" |
|
    def __init__(
        self,
        input_size,
        hidden_size,
        num_experts,
        top_k,
        bias=True,
        activation=None,
        glu=True,
    ):
        super().__init__()

        self.num_experts = num_experts
        self.input_size = input_size
        self.glu = glu
        if bias:
            # A single output bias shared across experts; the expert
            # projections themselves are bias-free.
            self.bias = torch.nn.Parameter(torch.empty(input_size))
            torch.nn.init.zeros_(self.bias)
        else:
            self.bias = None
        # Expert weights are stored stacked so the scattermoe kernels can
        # process all experts in one grouped matmul. With GLU, the first
        # projection is doubled to produce both the value and the gate.
        self.input_linear = scattermoe.parallel_experts.ParallelExperts(
            num_experts, input_size, hidden_size * 2 if glu else hidden_size
        )
        self.output_linear = scattermoe.parallel_experts.ParallelExperts(
            num_experts, hidden_size, input_size
        )
        self.top_k = min(top_k, self.num_experts)
        self.activation = activation

        self.router = top_k_gating(
            input_size=input_size,
            num_experts=num_experts,
            top_k=self.top_k,
        )
|
    def extra_repr(self):
        return 'k={}, e={}'.format(self.top_k, self.num_experts)
|
    def get_aux_loss_and_clear(self):
        """
        Get the accumulated auxiliary loss and clear it.

        Returns:
            float: Accumulated auxiliary loss.
        """
        # The gating module lives on ``self.router``; delegate to it.
        return self.router.get_aux_loss_and_clear()
|
    def compute_gate(self, x):
        # Route each token: top_k_indices holds the chosen experts,
        # top_k_gates their mixture weights.
        top_k_indices, self.top_k_gates = self.router(x)

        with torch.no_grad():
            # Sort the token-expert assignments by expert and pad each
            # expert's bucket to a block boundary, as required by the
            # scattermoe grouped kernels.
            self.sorted_expert_idxs, self.sorted_scattered_idxs = \
                scattermoe.kernels.ops.flatten_and_sort(top_k_indices)
            self.padded_block_idxs, self.expert_offsets = \
                scattermoe.kernels.ops.padded_block_indices(
                    self.sorted_expert_idxs, self.num_experts
                )

        return self.router.loss
|
    def batch_forward(self, x):
        """
        Forward pass of the mixture-of-experts layer for batched input.

        Args:
            x (Tensor): Input tensor of shape ``(bsz, length, input_size)``.

        Returns:
            Tensor: Output tensor of the same shape as the input.
            Tensor: The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)

        loss = self.compute_gate(x)

        # First grouped matmul: scatter tokens to their experts and keep the
        # result grouped by expert (grouped_out=True) so the second matmul
        # can consume it without re-sorting.
        h = self.input_linear(
            x, self.top_k,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
            grouped_out=True
        )

        if self.glu:
            # GLU: half of the projection is the value, the other half gates it.
            h, g = h.chunk(2, dim=-1)
            h = self.activation(h) * g
        else:
            h = self.activation(h)

        # Second grouped matmul: consume expert-grouped activations
        # (grouped_in=True) and scatter-add them back into token order,
        # weighted by the gating scores.
        y = self.output_linear(
            h, 1,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
            grouped_in=True,
            gates=self.top_k_gates,
        )

        y = y.view(bsz, length, self.input_size)
        if self.bias is not None:
            y = y + self.bias
        return y, loss
|
    def single_forward(self, x):
        # Dense fallback for a single token (e.g. incremental decoding):
        # a plain F.linear per selected expert is cheaper than launching
        # the grouped kernels.
        bsz, length, emb_size = x.size()

        x = x.reshape(1, self.input_size)
        top_k_indices, top_k_gates = self.router(x)
        loss = self.router.loss

        y_list = []
        for i in range(self.top_k):
            expert_idx = top_k_indices[0, i]

            h = F.linear(x, self.input_linear.weight[expert_idx])
            if self.glu:
                h, g = h.chunk(2, dim=-1)
                h = self.activation(h) * g
            else:
                h = self.activation(h)
            y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0, i]

            y_list.append(y)

        y = sum(y_list)
        y = y.view(bsz, length, self.input_size)
        if self.bias is not None:
            y = y + self.bias
        return y, loss
|
    def forward(self, x):
        """
        Forward pass of the mixture-of-experts layer. Dispatches to a dense
        per-expert path when the flattened batch contains exactly one token,
        and to the batched scattermoe path otherwise.

        Args:
            x (Tensor): Input tensor of shape ``(bsz, length, input_size)``.

        Returns:
            Tensor: Output tensor of the same shape as the input.
            Tensor: The router's auxiliary loss.
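
        Example (an illustrative sketch, not part of the original source;
        ``moe`` is built as in the class-level example):

            >>> single = torch.randn(1, 1, 512, device='cuda')
            >>> y, aux_loss = moe(single)  # bsz * length == 1: dense path
            >>> y.shape
            torch.Size([1, 1, 512])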
|
""" |
|
bsz, length, emb_size = x.size() |
|
if bsz * length ==1: |
|
return self.single_forward(x) |
|
else: |
|
return self.batch_forward(x) |
|
    def batch_map(self, x):
        """
        Map the input through the experts' first projection only, keeping
        one vector per selected expert; ``reduce`` applies the second
        projection and combines the results.

        Args:
            x (Tensor): Input tensor of shape ``(bsz, length, input_size)``.

        Returns:
            Tensor: Mapped tensor of shape ``(bsz, length, top_k, projection_size)``.
            Tensor: The router's auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        loss = self.compute_gate(x)

        # Without grouped_out, the kernel returns rows in token order, one
        # row per (token, selected expert) pair.
        y = self.input_linear(
            x, self.top_k,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
        )
        y = y.view(bsz, length, self.top_k, -1)
        return y, loss
|
    def single_map(self, x):
        # Dense single-token counterpart of ``batch_map``. The routing
        # results are cached on ``self`` for the matching ``single_reduce``.
        bsz, length, emb_size = x.size()

        x = x.reshape(1, self.input_size)
        self.top_k_indices, self.top_k_gates = self.router(x)
        loss = self.router.loss

        y_list = []
        for i in range(self.top_k):
            expert_idx = self.top_k_indices[0, i]
            y = F.linear(x, self.input_linear.weight[expert_idx])
            y_list.append(y)
        y = torch.cat(y_list, dim=0)
        y = y.view(bsz, length, self.top_k, -1)
        return y, loss
|
    def map(self, x):
        """
        Map the input through the experts' first projection, returning the
        per-expert vectors without combining them. Pair with ``reduce`` to
        complete the expert computation.

        Args:
            x (Tensor): Input tensor of shape ``(bsz, length, input_size)``.

        Returns:
            Tensor: Mapped tensor of shape ``(bsz, length, top_k, projection_size)``.
            Tensor: The router's auxiliary loss.
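
        Example (an illustrative sketch, not part of the original source;
        a CUDA device is assumed, and ``glu=False`` keeps the projection
        width equal to ``hidden_size``):

            >>> moe = MoE(input_size=512, hidden_size=1024, num_experts=8,
            ...           top_k=2, glu=False, activation=F.silu).cuda()
            >>> x = torch.randn(4, 16, 512, device='cuda')
            >>> h, aux_loss = moe.map(x)
            >>> h.shape  # one vector per selected expert
            torch.Size([4, 16, 2, 1024])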
|
""" |
|
bsz, length, emb_size = x.size() |
|
if bsz * length ==1: |
|
return self.single_map(x) |
|
else: |
|
return self.batch_map(x) |
|
    def batch_reduce(self, x):
        """
        Reduce the mapped output: apply each selected expert's second
        projection and combine the results with the cached gating weights.

        Args:
            x (Tensor): Mapped tensor of shape ``(bsz, length, top_k, hidden_size)``.

        Returns:
            Tensor: Reduced output tensor of shape ``(bsz, length, input_size)``.
        """
        bsz, length, k, emb_size = x.size()
        assert k == self.top_k
        x = x.reshape(-1, emb_size)

        # Reuses the routing metadata computed in ``batch_map``; the gates
        # weight each expert's contribution in the scatter-add.
        y = self.output_linear(
            x, 1,
            self.sorted_expert_idxs, self.sorted_scattered_idxs,
            self.padded_block_idxs, self.expert_offsets,
            gates=self.top_k_gates,
        )
        y = y.view(bsz, length, self.input_size)
        return y
|
    def single_reduce(self, x):
        # Dense single-token counterpart of ``batch_reduce``, using the
        # expert indices and gates cached by ``single_map``.
        bsz, length, k, emb_size = x.size()

        x = x.reshape(k, emb_size)

        y_list = []
        for i in range(self.top_k):
            expert_idx = self.top_k_indices[0, i]
            y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0, i]
            y_list.append(y)
        y = sum(y_list)
        y = y.view(bsz, length, self.input_size)
        return y
|
    def reduce(self, x):
        """
        Reduce the output of ``map``: apply the experts' second projection
        and sum the per-expert results, weighted by the gating scores.

        Args:
            x (Tensor): Mapped tensor of shape ``(bsz, length, top_k, hidden_size)``.

        Returns:
            Tensor: Reduced output tensor of shape ``(bsz, length, input_size)``.
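
        Example (an illustrative sketch, not part of the original source;
        it continues the ``map`` example, and applying the activation
        between ``map`` and ``reduce`` is the caller's responsibility here):

            >>> h, aux_loss = moe.map(x)  # (4, 16, 2, 1024)
            >>> y = moe.reduce(F.silu(h))
            >>> y.shape
            torch.Size([4, 16, 512])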
|
""" |
|
bsz, length, k, emb_size = x.size() |
|
if bsz * length ==1: |
|
return self.single_reduce(x) |
|
else: |
|
return self.batch_reduce(x) |