jetmoe-8b-chat / moe.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import scattermoe

from .gate import top_k_gating
class MoE(nn.Module):
"""
A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
Args:
input_size: integer - size of the input
head_size: integer - size of the expert's hidden layer
num_experts: an integer - number of experts
top_k: an integer - how many experts to use for each batch element
bias: a boolean - whether to include bias in linear layers
activation: an activation function to apply to expert's outputs
acc_aux_loss: a boolean - whether to accumulate auxiliary loss
hidden_size: an integer - hidden size of the experts
gating_dropout: a float - dropout rate for gating network
sample_topk: an integer - how many experts to sample during training
gating_size: an integer - size of the gating network
aux_loss: a string - type of auxiliary loss ('mi' or 'sparse')
gate_type: a string - type of gating mechanism ('mlp' or 'topk')
"""
def __init__(
self,
input_size,
hidden_size,
num_experts,
top_k,
bias=True,
activation=None,
glu=True,
):
        super().__init__()
self.num_experts = num_experts
self.input_size = input_size
self.glu = glu
if bias:
self.bias = torch.nn.Parameter(torch.empty(input_size))
torch.nn.init.zeros_(self.bias)
else:
self.bias = None
        self.input_linear = scattermoe.parallel_experts.ParallelExperts(
            num_experts, input_size, hidden_size * 2 if glu else hidden_size
        )
        self.output_linear = scattermoe.parallel_experts.ParallelExperts(
            num_experts, hidden_size, input_size
        )
self.top_k = min(top_k, self.num_experts)
self.activation = activation
        self.router = top_k_gating(
            input_size=input_size,
            num_experts=num_experts,
            top_k=self.top_k,
        )
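    # A minimal usage sketch (sizes are illustrative, not tied to any
    # checkpoint; assumes the caller supplies an activation such as F.silu):
    #
    #   moe = MoE(input_size=2048, hidden_size=5632, num_experts=8, top_k=2,
    #             activation=F.silu)
    #   y, aux_loss = moe(torch.randn(2, 16, 2048))
    #   assert y.shape == (2, 16, 2048)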
def extra_repr(self):
return 'k={}, e={}'.format(
self.top_k, self.num_experts)
    def get_aux_loss_and_clear(self):
        """
        Get the accumulated auxiliary loss and clear it.

        Returns:
            float: Accumulated auxiliary loss.
        """
        return self.router.get_aux_loss_and_clear()
    def compute_gate(self, x):
        # Route each token to its top-k experts and cache the sorted index
        # layout that scattermoe's grouped kernels expect.
        top_k_indices, self.top_k_gates = self.router(x)
        with torch.no_grad():
            self.sorted_expert_idxs, self.sorted_scattered_idxs = scattermoe.kernels.ops.flatten_and_sort(top_k_indices)
            self.padded_block_idxs, self.expert_offsets = scattermoe.kernels.ops.padded_block_indices(
                self.sorted_expert_idxs, self.num_experts
            )
        return self.router.loss
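    # Conceptually, the bookkeeping in compute_gate groups (token, expert)
    # assignments by expert id so each expert sees a contiguous block of rows.
    # A rough pure-PyTorch sketch of the sort step (not the actual kernel):
    #
    #   flat = top_k_indices.flatten()                   # (tokens * top_k,)
    #   sorted_expert_idxs, sorted_scattered_idxs = flat.sort()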
def batch_forward(self, x):
"""
Forward pass of the mixture of experts layer.
Args:
x (Tensor): Input tensor.
Returns:
Tensor: Output tensor.
"""
bsz, length, emb_size = x.size()
x = x.reshape(-1, emb_size)
loss = self.compute_gate(x)
h = self.input_linear(
x, self.top_k,
self.sorted_expert_idxs, self.sorted_scattered_idxs,
self.padded_block_idxs, self.expert_offsets,
grouped_out=True
)
        if self.glu:
            # GLU experts: split the doubled projection into a value half and
            # a gate half, then combine them.
            h, g = h.chunk(2, dim=-1)
            h = self.activation(h) * g
        else:
            h = self.activation(h)
y = self.output_linear(
h, 1,
self.sorted_expert_idxs, self.sorted_scattered_idxs,
self.padded_block_idxs, self.expert_offsets,
grouped_in=True,
gates=self.top_k_gates,
)
y = y.view(bsz, length, self.input_size)
if self.bias is not None:
y = y + self.bias
return y, loss
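    # Note: grouped_out=True above leaves h grouped by expert, which is the
    # layout output_linear consumes with grouped_in=True, so no intermediate
    # re-scatter of the hidden states is needed.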
    def single_forward(self, x):
        """Forward pass specialized for a single token (bsz * length == 1)."""
        bsz, length, emb_size = x.size()
        x = x.reshape(1, self.input_size)
top_k_indices, top_k_gates = self.router(x)
loss = self.router.loss
y_list = []
        for i in range(self.top_k):
            expert_idx = top_k_indices[0, i]
            h = F.linear(x, self.input_linear.weight[expert_idx])
            if self.glu:
                h, g = h.chunk(2, dim=-1)
                h = self.activation(h) * g
            else:
                h = self.activation(h)
            y = F.linear(h, self.output_linear.weight[expert_idx]) * top_k_gates[0, i]
            y_list.append(y)
y = sum(y_list)
y = y.view(bsz, length, self.input_size)
if self.bias is not None:
y = y + self.bias
return y, loss
    def forward(self, x):
        """
        Forward pass of the mixture-of-experts layer.

        Args:
            x (Tensor): Input tensor of shape (batch, seq_len, input_size).
        Returns:
            tuple: Output tensor of the same shape and the auxiliary routing loss.
        """
        bsz, length, emb_size = x.size()
        if bsz * length == 1:
            # Single-token case (e.g. incremental decoding) skips the grouped kernels.
            return self.single_forward(x)
        else:
            return self.batch_forward(x)
def batch_map(self, x):
"""
Map input through the mixture of experts layer.
Args:
x (Tensor): Input tensor.
Returns:
Tensor: Output tensor.
"""
bsz, length, emb_size = x.size()
x = x.reshape(-1, emb_size)
loss = self.compute_gate(x)
y = self.input_linear(
x, self.top_k,
self.sorted_expert_idxs, self.sorted_scattered_idxs,
self.padded_block_idxs, self.expert_offsets,
)
y = y.view(bsz, length, self.top_k, -1)
return y, loss
    def single_map(self, x):
        """Map specialized for a single token (bsz * length == 1)."""
        bsz, length, emb_size = x.size()
        x = x.reshape(1, self.input_size)
        self.top_k_indices, self.top_k_gates = self.router(x)
loss = self.router.loss
y_list = []
for i in range(self.top_k):
            expert_idx = self.top_k_indices[0, i]
y = F.linear(x, self.input_linear.weight[expert_idx])
y_list.append(y)
y = torch.cat(y_list, dim=0)
y = y.view(bsz, length, self.top_k, -1)
return y, loss
    def map(self, x):
        """
        Map the input through the experts' input projections without reducing.

        Args:
            x (Tensor): Input tensor of shape (batch, seq_len, input_size).
        Returns:
            tuple: Mapped tensor of shape (batch, seq_len, top_k, -1) and the
            auxiliary routing loss.
        """
        bsz, length, emb_size = x.size()
        if bsz * length == 1:
            return self.single_map(x)
        else:
            return self.batch_map(x)
def batch_reduce(self, x):
"""
Reduce the mapped output.
Args:
x (Tensor): Mapped output tensor.
Returns:
Tensor: Reduced output tensor.
"""
bsz, length, k, emb_size = x.size()
assert k == self.top_k
x = x.reshape(-1, emb_size)
y = self.output_linear(
x, 1,
self.sorted_expert_idxs, self.sorted_scattered_idxs,
self.padded_block_idxs, self.expert_offsets,
gates=self.top_k_gates,
)
y = y.view(bsz, length, self.input_size)
return y
    def single_reduce(self, x):
        """Reduce specialized for a single token (bsz * length == 1)."""
        bsz, length, k, emb_size = x.size()
        x = x.reshape(k, emb_size)
        y_list = []
        for i in range(self.top_k):
            expert_idx = self.top_k_indices[0, i]
            y = F.linear(x[i], self.output_linear.weight[expert_idx]) * self.top_k_gates[0, i]
y_list.append(y)
y = sum(y_list)
y = y.view(bsz, length, self.input_size)
return y
    def reduce(self, x):
        """
        Reduce the mapped output back to the input size.

        Args:
            x (Tensor): Mapped tensor of shape (batch, seq_len, top_k, hidden).
        Returns:
            Tensor: Reduced output tensor of shape (batch, seq_len, input_size).
        """
        bsz, length, k, emb_size = x.size()
        if bsz * length == 1:
            return self.single_reduce(x)
        else:
            return self.batch_reduce(x)
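
# How map/reduce relate to forward (a sketch for glu=False; note that reduce,
# unlike forward, does not add the output bias):
#
#   h, aux_loss = moe.map(x)   # (bsz, len, top_k, hidden) per-expert projections
#   h = moe.activation(h)      # caller applies whatever sits between the two projections
#   y = moe.reduce(h)          # gate-weighted combine back to (bsz, len, input_size)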