# Ocean-OCR / moe.py
# guoxy25's picture
# Upload 56 files
# c6dee39 verified
import torch
from torch.nn import functional as F
# @torch.compile
def moe_matmul(inputs, weight_list, group_index, linear_fn=torch.matmul):
    """Apply a per-token expert (MoE) linear map selected by ``group_index``.

    Tokens are gathered expert by expert, each expert's tokens are multiplied
    by that expert's weight in a single ``linear_fn`` call, and the results
    are scattered back to the original token order.

    Args:
        inputs: tensor of shape (bs, sl, dim).
        weight_list: list of expert weights, each of shape (dim, dim').
        group_index: integer tensor of shape (bs, sl); entry ``g`` routes the
            token to ``weight_list[g]``. Every entry is expected to lie in
            ``[0, len(weight_list))``. A token whose id matches no expert
            keeps encode index 0 and therefore receives row 0 of the grouped
            output.
        linear_fn: callable ``(tokens, weight) -> projected`` applied once per
            expert; ``tokens`` has shape (n_tokens_of_expert, dim) and may be
            empty.

    Returns:
        Tensor of shape (bs, sl, dim').

    Raises:
        ValueError: if ``weight_list`` is empty.

    Example (flattened group_index, 17 tokens, 2 experts)::

        group_index : 0 0 0 1 1 1 0 0 1 1 1 0 0 0 1 1 1

        expert 0 (8 tokens, offset 0)
          rank      : 0 1 2 . . . 3 4 . . . 5 6 7 . . .
        expert 1 (9 tokens, offset 8)
          rank+off  : . . . 8 9 10 . . 11 12 13 . . . 14 15 16

        encode index: 0 1 2 8 9 10 3 4 11 12 13 5 6 7 14 15 16

    ``encode index[p]`` is the row that token ``p`` occupies in the
    concatenation of all expert outputs, so selecting those rows restores the
    original token order.
    """
    if len(weight_list) == 0:
        raise ValueError("weight_list must contain at least one expert weight")

    bs, sl = group_index.size()
    num_tokens = bs * sl
    flat_inputs = inputs.reshape(num_tokens, inputs.size(-1))
    flat_groups = group_index.reshape(num_tokens).to(torch.int32)

    expert_outputs = []
    encode_index = torch.zeros(
        num_tokens, dtype=torch.int64, device=flat_groups.device)
    cur_offset = 0  # number of tokens already claimed by earlier experts
    for group_i, weight in enumerate(weight_list):
        mask = torch.eq(flat_groups, group_i)  # (bs * sl,)
        # (n_i, dim) x (dim, dim') -> (n_i, dim')
        expert_outputs.append(linear_fn(flat_inputs[mask], weight))
        # 0-based rank of each selected token inside this expert's slice.
        rank = torch.cumsum(mask.to(torch.int64), dim=0) - 1
        encode_index = encode_index + (cur_offset + rank) * mask
        # Advance by the token count rather than by max(index): the max of a
        # fully masked-out index is 0, which would reset the offset whenever
        # an expert receives no tokens.
        cur_offset = cur_offset + mask.sum()

    grouped = torch.cat(expert_outputs, dim=0)  # (bs * sl, dim')
    return torch.index_select(grouped, 0, encode_index).view(bs, sl, -1)
if __name__ == "__main__":
    # Smoke test: every expert is the identity matrix, so the output should
    # reproduce the inputs regardless of how tokens are routed.
    bs, sl, d = 13, 997, 97
    dtype = torch.bfloat16
    # Prefer the GPU, but keep the script runnable on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create leaf tensors directly; wrapping an existing tensor in
    # torch.tensor(...) copy-constructs it and emits a UserWarning.
    inputs = torch.randn([bs, sl, d], dtype=dtype).to(device).requires_grad_()
    group_num = 2
    # Alternative routing that exercises both experts:
    # group_index = torch.remainder(torch.randint(0, 6, (bs, sl)), group_num).to(device)
    # remainder(..., 1) is always 0, so every token is routed to expert 0.
    group_index = torch.remainder(torch.randint(0, 6, (bs, sl)), 1).to(device)
    weights = [
        torch.eye(d).to(device).to(dtype).requires_grad_()
        for _ in range(group_num)
    ]

    output = moe_matmul(inputs, weights, group_index)
    print(inputs - output)  # expected to be all zeros (identity experts)

    # Weight each token's output by (group id + 1) so the input gradient
    # reveals which expert the token was routed through.
    loss = torch.sum(output * (group_index + 1).to(dtype).view(bs, sl, 1))
    print(loss)
    loss.backward()
    # d(loss)/d(inputs[..., 0]) == group_index + 1, so this should print ones.
    print(inputs.grad[:, :, 0] - group_index.to(dtype))
    # The last expert receives no tokens here; its gradient is expected to be
    # all zeros.
    print(weights[-1].grad)