import torch
from torch.nn import functional as F

# @torch.compile
def moe_matmul(inputs, weight_list, group_index, linear_fn=lambda x, y: torch.matmul(x, y)):
    """
    inputs: tensor (bs, sl, dim)
    weight_list: MoE weights, list of [(dim, dim')]
    group_index: (bs, sl), max(group_index) + 1 == len(weight_list), 在sl维上表示分组信息
    group_nums: 表示MoE的个数
    example:
        拉平后bs*sl的group index 0 0 0 1 1 1 0 0 1 1 1 0 0 0 1 1 1  (17)
        按0, 1 分别正反编码index
        0: 
        cumsum: 0 1 2 2 2 2 3 4 4 4 4 5 6 7 7 7 7
        offset: same
        mask:   0 1 2 0 0 0 3 4 0 0 0 5 6 7 0 0 0
        new offset is 7
        1:
        cumsum: 0 0 0 1 2 3 3 3 4 5 6 6 6 6 7 8 9
        offset: 7 7 7 8 9 10 10 10 11 12 13 13 13 13 14 15 16 
        mask:   0 0 0 8 9 10 0 0 11 12 13 0 0 0 14 15 16
        new offset is 16
        ...
        合并encode映射码表
        0 1 2 8 9 10 3 4 11 12 13 5 6 7 14 15 16
        执行gather操作,之后将inputs按offset split 分别matmul 再concat
        decode映射码表
        0 1 2 8 9 10 3 4 11 12 13 5  6  7  14 15 16  index
        0 1 2 3 4 5  6 7 8  9  10 11 12 13 14 15 16  value
        :
        0 1 2 6 7 11 12 13 3 4 5 8 9 10 14 15 16

    """
    bs, sl = group_index.size()
    group_inputs, cur_offset, group_encode_index = [], 0, 0
    for group_i in range(len(weight_list)):
        group_i_mask = torch.eq(group_index.to(torch.int32), group_i).view(bs * sl)  # (bs * sl) bool mask of tokens routed to expert group_i
        # gather this group's tokens and run them through the expert weight
        group_inputs.append(linear_fn(
            torch.masked_select(inputs, group_i_mask.view(bs, sl, 1)).view(-1, inputs.size(-1)),
            weight_list[group_i]))  # (?, dim) X (dim, dim')
        # position of each group_i token inside the concatenated per-group output
        group_i_index = torch.cumsum(group_i_mask.view(bs * sl).to(torch.int64), axis=0)
        group_i_index -= 1 if group_i == 0 else 0  # make indices 0-based; only needed for the first group, later groups are shifted via cur_offset
        group_i_index = (cur_offset + group_i_index) * group_i_mask
        cur_offset = torch.max(group_i_index)  # last used 0-based position (assumes every non-final group receives at least one token)
        group_encode_index += group_i_index
    
    # gathering an arange with the encode table is the identity, so the encode table
    # itself is the index that restores the original token order from the concat
    group_inputs = torch.cat(group_inputs, axis=0)  # (bs * sl, dim')
    outputs = torch.index_select(group_inputs, 0, group_encode_index).view(bs, sl, -1)
    return outputs

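# A minimal cross-check sketch, not part of the original API: `moe_matmul_reference`
# is a hypothetical helper that routes each token through its assigned expert one
# by one. It is only meant to validate `moe_matmul` on small shapes; it assumes the
# same argument layout as `moe_matmul` and makes no claim about performance.
def moe_matmul_reference(inputs, weight_list, group_index):
    bs, sl, _ = inputs.size()
    outputs = torch.empty(bs, sl, weight_list[0].size(-1), dtype=inputs.dtype, device=inputs.device)
    for b in range(bs):
        for s in range(sl):
            # pick the expert assigned to this token and apply it directly
            expert = weight_list[int(group_index[b, s])]
            outputs[b, s] = inputs[b, s] @ expert
    return outputs
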

if __name__ == "__main__":
    bs, sl, d = 13, 997, 97
    dtype = torch.bfloat16
    inputs = torch.randn([bs, sl, d], dtype=dtype, device="cuda", requires_grad=True)
    group_num = 2
    group_index = torch.remainder(torch.randint(0, 6, (bs, sl)), group_num).cuda()
    # identity expert weights, so the forward pass should reproduce `inputs`
    weights = [torch.eye(d, dtype=dtype, device="cuda", requires_grad=True) for _ in range(group_num)]
    output = moe_matmul(inputs, weights, group_index)
    print(inputs - output)  # identity experts: should be (numerically) zero
    loss = torch.sum(output * (group_index+1).to(dtype).view(bs, sl, 1))
    print(loss)
    loss.backward()
    print(inputs.grad[:, :, 0] - group_index.to(dtype))  # d loss / d inputs == group_index + 1, so this should be all ones
    print(weights[-1].grad)
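    # Cross-check sketch (assumes the hypothetical moe_matmul_reference helper defined
    # above), using the 17-token routing pattern from the docstring example and random
    # experts so the comparison is non-trivial.
    toy_index = torch.tensor([[0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]], device="cuda")
    toy_inputs = torch.randn(1, toy_index.size(1), d, dtype=dtype, device="cuda")
    toy_weights = [torch.randn(d, d, dtype=dtype, device="cuda") for _ in range(group_num)]
    fast = moe_matmul(toy_inputs, toy_weights, toy_index)
    slow = moe_matmul_reference(toy_inputs, toy_weights, toy_index)
    print(torch.max(torch.abs(fast - slow)))  # expected to be ~0 up to bf16 rounding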