|
import torch |
|
from torch.nn import functional as F |
|
|
|
|
|
def moe_matmul(inputs, weight_list, group_index, linear_fn=lambda x, y: torch.matmul(x, y)):
    """Apply a per-token expert (MoE) linear map selected by ``group_index``.

    Every token position ``(b, s)`` is routed to expert ``group_index[b, s]``.
    The rows of each expert are gathered, transformed with a single call to
    ``linear_fn(rows, weight_list[expert])``, and the results are written back
    to their original token positions.

    Args:
        inputs: tensor of shape ``(bs, sl, dim)``.
        weight_list: one weight per expert, passed to ``linear_fn`` as its
            second argument (for the default ``linear_fn`` a ``(dim, dim_out)``
            matrix). All experts must produce the same trailing output shape.
        group_index: tensor of shape ``(bs, sl)`` whose values (after casting
            to int64) lie in ``[0, len(weight_list))``.
        linear_fn: callable ``(rows, weight) -> rows_out`` mapping an
            ``(n, dim)`` tensor to an ``(n, ...)`` tensor. Defaults to matmul.

    Returns:
        Tensor of shape ``(bs, sl, dim_out)`` with
        ``outputs[b, s] == linear_fn(inputs[b, s], weight_list[group_index[b, s]])``
        (trailing output dims beyond the first are flattened into ``dim_out``).

    Raises:
        ValueError: if ``weight_list`` is empty or ``group_index`` contains a
            value outside ``[0, len(weight_list))``.

    Example (flattened ``group_index`` of 17 tokens, two experts)::

        group_index : 0 0 0 1 1 1  0 0 1  1  1  0 0 0 1  1  1
        expert 0 has 8 tokens -> rows 0..7 of the concatenated outputs
        expert 1 has 9 tokens -> rows 8..16
        encode index: 0 1 2 8 9 10 3 4 11 12 13 5 6 7 14 15 16

    ``outputs[p] = concatenated_outputs[encode_index[p]]``.
    """
    if not weight_list:
        raise ValueError("weight_list must contain at least one expert weight")

    bs, sl = group_index.size()
    num_experts = len(weight_list)

    # reshape (not view) so non-contiguous tensors are accepted; the explicit
    # last dim keeps this well-defined when there are zero tokens.
    flat_inputs = inputs.reshape(bs * sl, inputs.size(-1))
    flat_groups = group_index.reshape(-1).to(torch.int64)

    # Out-of-range ids would otherwise silently read row 0 of the outputs.
    if bool(((flat_groups < 0) | (flat_groups >= num_experts)).any()):
        raise ValueError(
            f"group_index values must lie in [0, {num_experts}), "
            f"got range [{int(flat_groups.min())}, {int(flat_groups.max())}]"
        )

    expert_outputs = []
    # encode_index[p] = row of the concatenated expert outputs holding token p.
    encode_index = torch.zeros_like(flat_groups)
    offset = 0  # number of rows emitted by the experts processed so far
    for expert_id, weight in enumerate(weight_list):
        mask = flat_groups == expert_id
        rows = flat_inputs[mask]  # selected in ascending token order
        expert_outputs.append(linear_fn(rows, weight))
        # 0-based rank of each token within its expert, shifted past earlier
        # experts' rows; an empty expert simply adds 0 to the offset.
        rank = torch.cumsum(mask.to(torch.int64), dim=0) - 1
        encode_index = torch.where(mask, rank + offset, encode_index)
        offset += rows.size(0)

    concatenated = torch.cat(expert_outputs, dim=0)
    outputs = torch.index_select(concatenated, 0, encode_index)
    # Explicit feature size instead of -1 so zero-token inputs reshape cleanly.
    out_features = torch.Size(concatenated.shape[1:]).numel()
    return outputs.view(bs, sl, out_features)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: with identity expert weights the output must equal the input,
    # and since loss = sum(output * (group_index + 1)), d(loss)/d(inputs) must
    # equal group_index + 1 on every feature.
    bs, sl, d = 13, 997, 97
    dtype = torch.bfloat16
    group_num = 2
    device = "cuda" if torch.cuda.is_available() else "cpu"

    inputs = torch.randn(bs, sl, d, dtype=dtype, device=device, requires_grad=True)
    # Route each token to one of the `group_num` experts at random.
    group_index = torch.randint(0, group_num, (bs, sl), device=device)
    weights = [
        torch.eye(d, dtype=dtype, device=device, requires_grad=True)
        for _ in range(group_num)
    ]

    output = moe_matmul(inputs, weights, group_index)
    print("max |inputs - output|:", (inputs - output).abs().max().item())

    loss = torch.sum(output * (group_index + 1).to(dtype).view(bs, sl, 1))
    print("loss:", loss.item())

    loss.backward()

    expected_grad = (group_index + 1).to(dtype)
    print(
        "max |inputs.grad - (group_index + 1)|:",
        (inputs.grad[:, :, 0] - expected_grad).abs().max().item(),
    )
    print(weights[-1].grad)
|
|
|
|