import torch
import triton
import triton.language as tl


def assert_is_tensor(x, ndim):
    if x.ndim != ndim:
        raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


def assert_is_matrix(x):
    assert_is_tensor(x, 2)


def assert_is_vector(x):
    if x.ndim != 1:
        raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


def assert_equal(a, b):
    if a != b:
        raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.')
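

# _padded_copy moves one row per program between the token-ordered tensor `a`
# and the expert-ordered tensor `b`. For the i-th entry in expert-sorted order,
# `indices[i]` is its flattened (token, top_k) id and `bin_ids[i]` is its
# expert; `bins` / `padded_bins` hold inclusive cumulative (unpadded / padded)
# row counts per expert. A_TO_B picks the copy direction and SCALE multiplies
# each copied row by `weights[indices[i]]`.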
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy(
    a,
    b,
    indices,
    bin_ids,
    weights,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # The index of this program's row in the flattened (tokens * top_k) order.
    index_a = tl.load(indices + tl.program_id(0))

    # The expert ("bin") this row was routed to.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Position of this row within its bin: subtract the number of rows
    # assigned to all earlier bins.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Row index in 'b': the (padded) start of our bin plus our offset in it.
    index_b = offset_in_bin
    if bin_idx > 0:
        index_b += tl.load(padded_bins + bin_idx - 1)

    # Offset the input and output pointers. When copying A -> B, the same
    # input row is replicated for each of its top_k copies, so divide the
    # flattened index by TOP_K. The B -> A direction is reduced over top_k
    # outside the kernel (see padded_scatter).
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the copy direction.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    # Copy (and optionally scale) the row in BLOCK_X-wide chunks.
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X
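

# Gathers the rows of `x` into expert order with per-expert padding: the output
# has `padded_bins[-1]` rows and each input row appears once per top_k copy,
# optionally scaled by `weights`.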
def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: Rows between an expert's last token and its padded end are never
    # written, so the output must be zero-initialized.
    output_rows = padded_bins[-1].cpu().item()
    out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out
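

# The unpadded variant of padded_gather: `bins` is passed for both the real and
# padded bin boundaries, so the output has exactly `tokens * top_k` rows.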
def gather(x, indices, bin_ids, weights, bins, top_k):
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)
    assert_equal(bin_ids.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: With no padding every output row is written, so the output can be
    # left uninitialized.
    output_rows = x.shape[0] * top_k
    out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        x,
        out,
        indices,
        bin_ids,
        weights,
        bins,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out
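

# Copies rows from the padded, expert-ordered `x` back to token order,
# optionally scaling each row by its routing weight, and sums the top_k copies
# of every token (undoing padded_gather's layout).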
def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
    _padded_copy[(indices.shape[0],)](
        out,
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


def scatter(x, indices, bin_ids, weights, bins, top_k):
    return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
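

# _padded_copy_wgrad computes the gradient with respect to the scatter weights:
# for each (token, top_k) pair it reduces the elementwise product of the
# expert-ordered row of `x` with the matching row of `grad` into a single
# scalar in `wgrad`.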
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _padded_copy_wgrad(
    x,
    grad,
    wgrad,
    indices,
    bin_ids,
    bins,
    padded_bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # The index of this program's (token, top_k) pair in the flattened order.
    index_out = tl.load(indices + tl.program_id(0))

    # The expert ("bin") this pair was routed to.
    bin_idx = tl.load(bin_ids + tl.program_id(0))

    # Position of this pair within its bin.
    offset_in_bin = tl.program_id(0)
    if bin_idx > 0:
        offset_in_bin -= tl.load(bins + bin_idx - 1)

    # Row index in the padded, expert-ordered tensor 'x'.
    index_x = offset_in_bin
    if bin_idx > 0:
        index_x += tl.load(padded_bins + bin_idx - 1)

    # Offset the pointers: one scalar output per pair, one row of 'grad' per
    # token and one row of 'x' per padded slot.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Accumulate the product of the two rows in BLOCK_X-wide chunks.
    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to a scalar and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)
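

# Launches _padded_copy_wgrad over every (token, top_k) pair and returns the
# flat (tokens * top_k,) tensor of per-pair weight gradients.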
def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
    assert_is_matrix(x)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bin_ids)
    assert_is_vector(bins)
    assert_is_vector(padded_bins)
    assert_equal(indices.shape[0], bin_ids.shape[0])
    assert_equal(bins.size(), padded_bins.size())

    tokens = indices.shape[0] // top_k
    out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
    _padded_copy_wgrad[(indices.shape[0],)](
        x,
        grad,
        out,
        indices,
        bin_ids,
        bins,
        padded_bins,
        NUM_COLUMNS=x.shape[1],
        TOP_K=top_k,
    )
    return out


def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
    return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
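

# _binned_copy is the fixed-capacity analogue of _padded_copy: `b` is laid out
# as (num_experts, expert_capacity, hidden) and the launch grid is
# (num_experts, expert_capacity). Slots past an expert's token count are
# skipped; `bins` holds inclusive cumulative token counts per expert.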
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy(
    a,
    b,
    num_experts,
    expert_capacity,
    indices,
    weights,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
    A_TO_B: tl.constexpr,
    SCALE: tl.constexpr,
):
    # Which expert and which capacity slot this program handles.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Row index in the (num_experts, expert_capacity, hidden) tensor 'b'.
    index_b = expert_idx * expert_capacity + entry_idx

    # The range of sorted (token, top_k) pairs assigned to this expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Nothing to copy for capacity slots beyond the expert's token count.
    if entry_idx >= num_tokens:
        return
    index_a = tl.load(indices + start + entry_idx)

    # Offset the input and output pointers. When copying A -> B, the same
    # input row is replicated for each of its top_k copies; the B -> A
    # direction is reduced over top_k outside the kernel (see binned_scatter).
    offset = index_a // TOP_K if A_TO_B else index_a
    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Load the scale, if requested.
    scale = tl.load(weights + index_a) if SCALE else 1

    # Swap the pointers depending on the copy direction.
    iptr = a if A_TO_B else b
    optr = b if A_TO_B else a

    # Copy (and optionally scale) the row in BLOCK_X-wide chunks.
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        x = tl.load(iptr + offsets, mask=mask)
        x = x.to(tl.float32) * scale.to(tl.float32)

        tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)

        offsets += BLOCK_X
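

# Gathers rows of `x` into a zero-initialized (num_experts, expert_capacity,
# hidden) tensor; unused capacity slots stay zero and assignments beyond an
# expert's capacity are dropped.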
def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
    assert_is_matrix(x)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(indices.shape[0], x.shape[0] * top_k)

    if weights is not None:
        assert_equal(weights.shape[0], x.shape[0] * top_k)

    # NOTE: Capacity slots without an assigned token are never written, so the
    # output must be zero-initialized.
    num_experts = bins.shape[0]
    out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

    _binned_copy[(num_experts, expert_capacity)](
        x,
        out,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=x.shape[1],
        A_TO_B=True,
        TOP_K=top_k,
        SCALE=weights is not None,
    )
    return out
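

# Reverses binned_gather: copies each populated capacity slot back to its token
# position (optionally scaled by `weights`), leaves dropped tokens at zero, and
# sums the top_k copies of each token.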
def binned_scatter(x, indices, weights, bins, top_k):
    assert_is_tensor(x, 3)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    if weights is not None:
        assert_equal(indices.shape[0], weights.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
    _binned_copy[(num_experts, expert_capacity)](
        out,
        x,
        num_experts,
        expert_capacity,
        indices,
        weights,
        bins,
        NUM_COLUMNS=hidden_size,
        A_TO_B=False,
        TOP_K=top_k,
        SCALE=weights is not None,
    )

    # Reduce along the top-k dimension, if needed.
    return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
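

# _binned_copy_wgrad computes the gradient with respect to the scatter weights
# for the binned layout: for each populated (expert, capacity slot) it reduces
# the elementwise product of the slot's row of `x` with the matching token's
# row of `grad` into a single scalar in `wgrad`.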
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_X': 64}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=2),
        triton.Config({'BLOCK_X': 256}, num_warps=2),
        triton.Config({'BLOCK_X': 128}, num_warps=4),
        triton.Config({'BLOCK_X': 256}, num_warps=4),
    ],
    key=['NUM_COLUMNS'],
)
@triton.jit
def _binned_copy_wgrad(
    x,
    grad,
    wgrad,
    num_experts,
    expert_capacity,
    indices,
    bins,
    NUM_COLUMNS: tl.constexpr,
    TOP_K: tl.constexpr,
    BLOCK_X: tl.constexpr,
):
    # Which expert and which capacity slot this program handles.
    expert_idx = tl.program_id(0)
    entry_idx = tl.program_id(1)

    # Row index in the (num_experts, expert_capacity, hidden) tensor 'x'.
    index_x = expert_idx * expert_capacity + entry_idx

    # The range of sorted (token, top_k) pairs assigned to this expert.
    start = 0
    if expert_idx > 0:
        start = tl.load(bins + expert_idx - 1)
    end = tl.load(bins + expert_idx)
    num_tokens = end - start

    # Nothing to do for capacity slots beyond the expert's token count.
    if entry_idx >= num_tokens:
        return
    index_out = tl.load(indices + start + entry_idx)

    # Offset the pointers: one scalar output per pair, one row of 'grad' per
    # token and one row of 'x' per capacity slot.
    wgrad += index_out
    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
    offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)

    # Accumulate the product of the two rows in BLOCK_X-wide chunks.
    acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
    iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
    for _ in range(iterations):
        mask = offsets < NUM_COLUMNS
        data = tl.load(x + offsets, mask=mask).to(tl.float32)
        scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
        acc += data * scale
        offsets += BLOCK_X

    # Reduce to a scalar and store.
    out = tl.sum(acc).to(wgrad.dtype.element_ty)
    tl.store(wgrad, out)


def binned_scatter_wgrad(x, grad, indices, bins, top_k):
    assert_is_tensor(x, 3)
    assert_is_matrix(grad)
    assert_is_vector(indices)
    assert_is_vector(bins)
    assert_equal(bins.shape[0], x.shape[0])

    num_experts, expert_capacity, hidden_size = x.shape
    tokens = indices.shape[0] // top_k
    out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
    _binned_copy_wgrad[(num_experts, expert_capacity)](
        x,
        grad,
        out,
        num_experts,
        expert_capacity,
        indices,
        bins,
        NUM_COLUMNS=hidden_size,
        TOP_K=top_k,
    )
    return out
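

# A minimal usage sketch, not part of the library surface: it routes four
# tokens across two experts with top_k=1, gathers them into expert order, then
# scatters them back. The routing metadata (indices, bin_ids, bins) is built by
# hand for illustration; in real use it comes from a router's top-k assignments
# followed by a sort and a cumulative histogram. Assumes a CUDA device.
if __name__ == '__main__' and torch.cuda.is_available():
    x = torch.randn(4, 8, device='cuda', dtype=torch.float16)
    # Tokens 1 and 3 go to expert 0; tokens 0 and 2 go to expert 1.
    indices = torch.tensor([1, 3, 0, 2], dtype=torch.int32, device='cuda')
    bin_ids = torch.tensor([0, 0, 1, 1], dtype=torch.int32, device='cuda')
    bins = torch.tensor([2, 4], dtype=torch.int32, device='cuda')

    expert_ordered = gather(x, indices, bin_ids, None, bins, top_k=1)
    restored = scatter(expert_ordered, indices, bin_ids, None, bins, top_k=1)
    assert torch.equal(restored, x)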