File size: 2,716 Bytes
9c4ca75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from megablocks import ops
PADDED_GATHER_TESTS = (
(4, 2, 2, 1),
(4, 2, 2, 2),
(1024, 1, 4, 1),
(1024, 1, 4, 2),
(1024, 1, 4, 4),
(1024, 1, 64, 1),
(1024, 1, 64, 2),
(1024, 1, 64, 4),
(1024, 1, 128, 1),
(1024, 1, 128, 2),
(1024, 1, 128, 4),
(1024, 1536, 4, 1),
(1024, 1536, 4, 2),
(1024, 1536, 4, 4),
(1024, 1536, 64, 1),
(1024, 1536, 64, 2),
(1024, 1536, 64, 4),
(1024, 1536, 128, 1),
(1024, 1536, 128, 2),
(1024, 1536, 128, 4),
(16384, 768, 4, 1),
(16384, 768, 4, 2),
(16384, 768, 4, 4),
(16384, 768, 64, 1),
(16384, 768, 64, 2),
(16384, 768, 64, 4),
(16384, 768, 128, 1),
(16384, 768, 128, 2),
(16384, 768, 128, 4),
(16384, 1, 4, 1),
(16384, 1, 4, 2),
(16384, 1, 4, 4),
(16384, 1, 64, 1),
(16384, 1, 64, 2),
(16384, 1, 64, 4),
(16384, 1, 128, 1),
(16384, 1, 128, 2),
(16384, 1, 128, 4),
)
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_GATHER_TESTS)
def testPaddedGather(sl: int, hs: int, ne: int, top_k: int):
# Create the data and indices.
x = torch.randn((sl, hs)).cuda().half()
# Randomly assign tokens to experts.
top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
def padded_gather(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
x = x.cpu().numpy()
indices = indices.cpu().numpy()
bin_ids = bin_ids.cpu().numpy()
bins = bins.cpu().numpy()
padded_bins = padded_bins.cpu().numpy()
out = np.zeros((padded_bins[-1], hs))
in_idx = 0
for i, end in enumerate(bins):
out_idx = 0 if i == 0 else padded_bins[i - 1]
end = bins[i]
while in_idx < end:
load_idx = indices[in_idx] // top_k
out[out_idx, :] = x[load_idx, :]
in_idx += 1
out_idx += 1
return torch.from_numpy(out).cuda().half()
out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
expected_out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
assert torch.all(torch.eq(out, expected_out))
|