Crystalcareai commited on
Commit
6317b4a
·
verified ·
1 Parent(s): da696ed

Delete single.py

Browse files
Files changed (1) hide show
  1. single.py +0 -60
single.py DELETED
@@ -1,60 +0,0 @@
1
- import torch
2
- import triton
3
- import triton.language as tl
4
- from torch.nn import functional as F
5
-
6
- @triton.jit
7
- def _single2scatter(
8
- X_ptr, stride_xm, stride_xk,
9
- W_ptr, stride_we, stride_wk, stride_wn,
10
- Y_ptr, stride_ym, stride_yn,
11
- expert_idxs_ptr,
12
- FAN_OUT: tl.constexpr,
13
- K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
14
- BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
15
- ACC_TYPE: tl.constexpr,
16
- ):
17
- pid0 = tl.program_id(axis=0)
18
- pid1 = tl.program_id(axis=1)
19
-
20
- N_block_id = pid0
21
- if FAN_OUT == 1:
22
- in_idx = pid1
23
- else:
24
- in_idx = 0
25
- out_idx = pid1
26
-
27
- K_block = tl.arange(0, BLOCK_K)
28
- N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
29
- E_idx = tl.load(expert_idxs_ptr + pid1)
30
- X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
31
- W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
32
- acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
33
- for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
34
- x = tl.load(X_blk_ptrs)
35
- w = tl.load(W_blk_ptrs)
36
- acc += tl.sum(x * w, axis=0)[None, :]
37
- X_blk_ptrs += BLOCK_K * stride_xk
38
- W_blk_ptrs += BLOCK_K * stride_wk
39
- Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
40
- tl.store(Y_blk_ptrs, acc)
41
-
42
- def single2scatter(X, W, expert_idxs):
43
- E, xdim, ydim = W.size()
44
- k = expert_idxs.size(1)
45
- assert X.size(0) == k or X.size(0) == 1
46
- Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
47
- BLOCK_N = 128
48
- BLOCK_K = 128
49
- grid = ydim // BLOCK_N, k
50
- _single2scatter[grid](
51
- X, X.stride(0), X.stride(1),
52
- W, W.stride(0), W.stride(1), W.stride(2),
53
- Y, Y.stride(0), Y.stride(1),
54
- expert_idxs,
55
- FAN_OUT=Y.size(0) // X.size(0),
56
- K=xdim, N=ydim, E=E,
57
- BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
58
- ACC_TYPE=tl.float32
59
- )
60
- return Y