leonardlin committed
Commit 1e407f0 · 1 Parent(s): eb55039

Add ROCm build artifacts and HIP backend

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitignore +3 -1
  2. README.md +2 -2
  3. build.py +73 -0
  4. build.sh +9 -0
  5. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/__init__.py +202 -0
  6. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/__init__.py +10 -0
  7. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/activation_fn.py +33 -0
  8. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/all_to_all.py +54 -0
  9. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/arguments.py +101 -0
  10. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/common.py +26 -0
  11. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmlp_registry.py +42 -0
  12. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmoe.py +337 -0
  13. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/gelu.py +52 -0
  14. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/glu.py +244 -0
  15. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.py +103 -0
  16. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.sh +12 -0
  17. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mlp.py +587 -0
  18. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/moe.py +507 -0
  19. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mpu.py +94 -0
  20. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/router.py +116 -0
  21. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +32 -0
  22. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_megablocks_rocm.so +3 -0
  23. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_ops.py +18 -0
  24. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_version.py +6 -0
  25. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/__init__.py +2 -0
  26. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/kernels.py +557 -0
  27. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/bak.__init__.py +23 -0
  28. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/benchmark_util.py +35 -0
  29. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  30. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/backend.py +33 -0
  31. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  32. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm_util.py +31 -0
  33. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/layers.py +1225 -0
  34. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/__init__.py +35 -0
  35. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +63 -0
  36. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.sh +12 -0
  37. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_gather.py +37 -0
  38. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_scatter.py +59 -0
  39. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/cumsum.py +52 -0
  40. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/gather.py +38 -0
  41. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram.py +27 -0
  42. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram_benchmark.py +78 -0
  43. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/matmul_benchmark.py +415 -0
  44. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_gather.py +55 -0
  45. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter.py +98 -0
  46. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +66 -0
  47. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/permute_benchmark.py +149 -0
  48. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/repeat.py +10 -0
  49. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/replicate.py +36 -0
  50. build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/round_up.py +14 -0
.gitignore CHANGED
@@ -2,4 +2,6 @@
  __pycache__
  .bak
  megablocks-moe/.bak
- .pytest_cache
+ .pytest_cache
+ .readme_example.py.swp
+ .torch_extensions/
README.md CHANGED
@@ -7,7 +7,7 @@ tags:
  ## Quickstart

  ```bash
- uv run https://huggingface.co/kernels-community/megablocks/raw/main/readme_example.py
+ uv run https://huggingface.co/shisa-ai/megablocks-hip/raw/main/readme_example.py
  ```

  ```python
@@ -30,7 +30,7 @@ torch.manual_seed(42)
  torch.cuda.manual_seed(42)

  # Download optimized kernels from the Hugging Face hub
- megablocks = get_kernel("kernels-community/megablocks")
+ megablocks = get_kernel("shisa-ai/megablocks-hip")
  print("MegaBlocks kernel downloaded successfully.")

  model = megablocks.layers.MegaBlocksMoeMLP()
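Beyond the MoE layer shown in the Quickstart, this package also re-exports the raw kernel wrappers defined in `megablocks/__init__.py` below. A minimal sketch of calling them through the hub-loaded module, assuming the `kernels` package and a ROCm-capable GPU are available:

```python
# Minimal sketch, assuming `kernels` can fetch shisa-ai/megablocks-hip and a HIP GPU is present.
import torch
from kernels import get_kernel

megablocks = get_kernel("shisa-ai/megablocks-hip")

# Exercise the direct kernel exports defined in megablocks/__init__.py below.
top_experts = torch.randint(0, 8, (128,), dtype=torch.int32, device="cuda")
counts = megablocks.histogram(top_experts, 8)       # tokens routed to each of 8 experts
bins = megablocks.cumsum(counts, dim=0)             # inclusive cumsum -> bin boundaries
sorted_ids, order = megablocks.argsort(top_experts) # radix sort with index tracking
print(counts, bins, sorted_ids[:8], order[:8])
```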
build.py ADDED
@@ -0,0 +1,73 @@
+
+ import os
+ import pathlib
+ import shutil
+
+ from torch.utils.cpp_extension import load
+
+ try:
+     from kernels.utils import build_variant
+ except ImportError:  # fallback when kernels is unavailable
+     build_variant = None
+
+ repo = pathlib.Path(__file__).resolve().parent
+ os.environ.setdefault("TORCH_EXTENSIONS_DIR", str(repo / ".torch_extensions"))
+
+ sources = [
+     repo / "torch-ext" / "torch_binding.cpp",
+     repo / "csrc" / "new_cumsum.cu",
+     repo / "csrc" / "new_histogram.cu",
+     repo / "csrc" / "new_indices.cu",
+     repo / "csrc" / "new_replicate.cu",
+     repo / "csrc" / "new_sort.cu",
+     repo / "csrc" / "grouped_gemm" / "grouped_gemm.cu",
+ ]
+
+ mod = load(
+     name="_megablocks_rocm",
+     sources=[str(s) for s in sources],
+     extra_include_paths=[str(repo / "csrc")],
+     extra_cflags=["-O3", "-std=c++17"],
+     extra_cuda_cflags=["-O3"],  # torch switches this to hipcc flags on ROCm builds
+     extra_ldflags=["-lhipblaslt"],
+     verbose=True,
+     is_python_module=False,
+ )
+
+ module_path = pathlib.Path(mod if isinstance(mod, str) else mod.__file__)
+ print("built:", module_path)
+
+ if build_variant is None:
+     print("kernels not available; skipping package staging")
+ else:
+     variant = build_variant()
+     package_root = repo / "build" / variant / "megablocks"
+     if package_root.exists():
+         shutil.rmtree(package_root)
+     shutil.copytree(
+         repo / "torch-ext" / "megablocks",
+         package_root,
+         ignore=shutil.ignore_patterns("__pycache__"),
+     )
+     ops_py = package_root / "_ops.py"
+     ops_py.write_text('''
+ import torch
+ from pathlib import Path
+
+ _LIB_NAME = "_megablocks_rocm.so"
+
+
+ def _load_ops():
+     lib_path = Path(__file__).with_name(_LIB_NAME)
+     torch.ops.load_library(str(lib_path))
+     return torch.ops._megablocks_rocm
+
+
+ ops = _load_ops()
+
+
+ def add_op_namespace_prefix(op_name: str) -> str:
+     return f"_megablocks_rocm::{op_name}"
+ ''')
+     shutil.copy2(module_path, package_root / module_path.name)
+     print(f"staged local kernel under build/{variant}/megablocks")
build.sh ADDED
@@ -0,0 +1,9 @@
+ # export TORCH_EXTENSIONS_DIR=/root/shisa-v2/train/v2.1/megablocks.kernels-community/.torch_extensions; export ROCM_HOME=/opt/rocm-6.4.1; export HIP_HOME=$ROCM_HOME; export TORCH_HIP_ARCH_LIST=gfx942; export HSA_OVERRIDE_GFX_VERSION=gfx942; python megablocks.kernels-community/build.py
+
+ # 3-4 min build
+ export ROCM_HOME=/opt/rocm-6.4.1
+ export HIP_HOME=$ROCM_HOME
+ export TORCH_HIP_ARCH_LIST=gfx942
+ export HSA_OVERRIDE_GFX_VERSION=gfx942
+ export TORCH_EXTENSIONS_DIR="$PWD/megablocks.kernels-community/.torch_extensions"
+ python megablocks.kernels-community/build.py
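Before or after running the script, it can be worth confirming that the installed PyTorch is actually a ROCm build; a small sanity check (a sketch, not part of the build):

```python
import torch

print(torch.version.hip)              # a version string on ROCm builds, None on CUDA builds
print(torch.cuda.is_available())      # True when an AMD GPU is visible through HIP
print(torch.cuda.get_device_name(0))  # should report a gfx942-class device for this script
```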
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/__init__.py ADDED
@@ -0,0 +1,202 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+
+ from ._ops import ops
+
+ from .grouped_gemm import backend as gg_backend
+ from .grouped_gemm import ops as gg_ops
+
+
+ from ._layers.arguments import Arguments
+ from ._layers.dmoe import ParallelDroplessMLP, dMoE
+ from ._layers.glu import SparseGLU
+ from ._layers.mlp import MLP, SparseMLP
+ from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
+
+ from . import layers
+
+ # This section contains the direct kernel exports (not included in the original code)
+ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+     """
+     Compute exclusive cumulative sum along the specified dimension.
+
+     Args:
+         x: Input tensor
+         dim: Dimension along which to compute cumsum
+         out: Output tensor (modified in-place)
+
+     Returns:
+         The output tensor
+     """
+     result = ops.exclusive_cumsum(x, dim)
+     out.copy_(result)
+     return out
+
+
+ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
+     """
+     Compute inclusive cumulative sum along the specified dimension.
+
+     Args:
+         x: Input tensor
+         dim: Dimension along which to compute cumsum
+         out: Output tensor (modified in-place)
+
+     Returns:
+         The output tensor
+     """
+     result = ops.inclusive_cumsum(x, dim)
+     out.copy_(result)
+     return out
+
+
+ def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
+     """
+     Compute histogram of input tensor values.
+
+     Args:
+         x: Input tensor
+         num_bins: Number of histogram bins
+
+     Returns:
+         Histogram tensor with counts for each bin
+     """
+     return ops.histogram(x, num_bins)
+
+
+ def indices(
+     padded_bins: torch.Tensor,
+     block_size: int,
+     output_block_rows: int,
+     output_block_columns: int,
+ ) -> torch.Tensor:
+     """
+     Construct indices from padded bins for sparse operations.
+
+     Args:
+         padded_bins: Tensor containing bin boundaries
+         block_size: Size of each block
+         output_block_rows: Number of rows in output blocks
+         output_block_columns: Number of columns in output blocks
+
+     Returns:
+         Tensor containing constructed indices
+     """
+     return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
+
+
+ def replicate_forward(
+     x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Forward pass of replicate operation - replicate values according to bin sizes.
+
+     Args:
+         x: Input tensor with values to replicate
+         bins: Tensor containing bin sizes
+         out: Output tensor (modified in-place)
+
+     Returns:
+         The output tensor
+     """
+     return ops.replicate_forward(x, bins, out)
+
+
+ def replicate_backward(
+     grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Backward pass of replicate operation - reduce gradients back to bins.
+
+     Args:
+         grad: Gradient tensor to reduce
+         bins: Tensor containing bin sizes
+         out: Output tensor (modified in-place)
+
+     Returns:
+         The output tensor
+     """
+     return ops.replicate_backward(grad, bins, out)
+
+
+ def sort(
+     x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Radix sort with index tracking.
+
+     Args:
+         x: Input tensor to sort
+         end_bit: Number of bits to consider in sorting
+         x_out: Output tensor for sorted values
+         iota_out: Output tensor for sorted indices
+
+     Returns:
+         The sorted values tensor
+     """
+     return ops.sort(x, end_bit, x_out, iota_out)
+
+
+ # Convenience functions for common use cases
+ def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
+     """
+     Compute cumulative sum with automatic output allocation.
+
+     Args:
+         x: Input tensor
+         dim: Dimension along which to compute cumsum (default: last dimension)
+         exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
+
+     Returns:
+         New tensor containing the cumulative sum
+     """
+     out = torch.empty_like(x)
+     if exclusive:
+         return exclusive_cumsum(x, dim, out)
+     else:
+         return inclusive_cumsum(x, dim, out)
+
+
+ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Sort tensor and return both sorted values and indices.
+
+     Args:
+         x: Input tensor to sort
+         end_bit: Number of bits to consider in sorting
+
+     Returns:
+         Tuple of (sorted_values, sorted_indices)
+     """
+     x_out = torch.empty_like(x)
+     iota_out = torch.empty_like(x)
+     sort(x, end_bit, x_out, iota_out)
+     return x_out, iota_out
+
+
+ # Export public API
+ __all__ = [
+     "MyReplacementLayer",
+     # Direct kernel exports
+     "exclusive_cumsum",
+     "inclusive_cumsum",
+     "histogram",
+     "indices",
+     "replicate_forward",
+     "replicate_backward",
+     "sort",
+     "cumsum",
+     "argsort",
+     # Original exports
+     "Arguments",
+     "ParallelDroplessMLP",
+     "dMoE",
+     "SparseGLU",
+     "MLP",
+     "SparseMLP",
+     "MoE",
+     "ParallelMLP",
+     "get_load_balancing_loss",
+ ]
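The wrappers above follow two conventions: the low-level exports write into a caller-provided `out` tensor, while the convenience helpers allocate their own output. A short sketch of both styles, assuming the staged package is importable and a GPU is present:

```python
import torch
import megablocks  # the staged build/<variant>/megablocks package

x = torch.randint(0, 4, (16,), dtype=torch.int32, device="cuda")

# Low-level style: caller supplies the output buffer.
out = torch.empty_like(x)
megablocks.inclusive_cumsum(x, 0, out)

# Convenience style: allocation handled internally.
same = megablocks.cumsum(x, dim=0)
assert torch.equal(out, same)
```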
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ # from megablocks.layers.dmoe import dMoE
+ from .moe import MoE
+
+ __all__ = [
+     'MoE',
+     # 'dMoE',
+ ]
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/activation_fn.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ from typing import Any, Callable, Union
+
+ import torch
+ from ..stk import Matrix
+
+
+ def act_fn(
+     x: Matrix,
+     function: Callable,
+     return_grad_fn: bool = False,
+     **kwargs,
+ ) -> Union[tuple[Matrix, Any] | Matrix]:
+     assert isinstance(x, Matrix)
+     with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+         if return_grad_fn:
+             x.data.requires_grad = True
+         out = function(x.data, **kwargs)
+         y = Matrix(
+             x.size(),
+             out,
+             x.row_indices,
+             x.column_indices,
+             x.offsets,
+             x.column_indices_t,
+             x.offsets_t,
+             x.block_offsets_t,
+         )
+         if return_grad_fn:
+             return y, out.backward
+         return y
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/all_to_all.py ADDED
@@ -0,0 +1,54 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+ import torch.distributed as dist
+
+
+ class AllToAllOp(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
+         out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
+
+         ctx.input_shape = x.shape
+         ctx.output_split_sizes = output_split_sizes
+         ctx.input_split_sizes = input_split_sizes
+         ctx.group = group
+         handle = dist.all_to_all_single(
+             out,
+             x,
+             output_split_sizes=output_split_sizes,
+             input_split_sizes=input_split_sizes,
+             group=group,
+             async_op=async_op,
+         )
+         return out, handle
+
+     @staticmethod
+     def backward(ctx, grad, _):
+         if ctx.needs_input_grad[0]:
+             out = torch.empty(
+                 ctx.input_shape,
+                 device=grad.device,
+                 dtype=grad.dtype,
+             )
+             dist.all_to_all_single(
+                 out,
+                 grad,
+                 output_split_sizes=ctx.input_split_sizes,
+                 input_split_sizes=ctx.output_split_sizes,
+                 group=ctx.group,
+             )
+             return out, None, None, None, None
+         return None, None, None, None, None
+
+
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
+     return AllToAllOp.apply(
+         x,
+         output_split_sizes,
+         input_split_sizes,
+         group,
+         async_op,
+     )
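`all_to_all` wraps `dist.all_to_all_single` in an autograd Function so the token permutation is differentiable; the backward pass simply swaps the split sizes. A usage sketch, assuming a process group launched with torchrun (sizes are illustrative):

```python
import torch
import torch.distributed as dist
from megablocks._layers.all_to_all import all_to_all

dist.init_process_group(backend="nccl")  # maps to RCCL on ROCm
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

x = torch.randn(4 * world, 8, device="cuda", requires_grad=True)
splits = [4] * world  # send 4 rows to every rank, receive 4 rows from every rank

out, handle = all_to_all(x, splits, splits, dist.group.WORLD)
out.sum().backward()  # gradients flow back through the reverse all-to-all
```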
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/arguments.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import dataclasses
+ from functools import partial
+ from typing import Any, Callable, Optional, Union
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
+
+ # import megablocks.grouped_gemm_util as grouped_gemm
+ from .. import grouped_gemm_util as grouped_gemm
+
+ # Type annotation for in-place Tensor initialization function.
+ InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]
+
+ _ALLOWED_BITWIDTHS = (-1, 4, 8)
+
+ DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
+
+
+ @dataclasses.dataclass
+ class Arguments:
+     # Model arguments.
+     hidden_size: int = 1024
+     ffn_hidden_size: int = 4096
+     num_layers: int = 1
+     bias: bool = True
+     return_bias: bool = True
+     activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN
+
+     # MoE arguments.
+     moe_num_experts: int = 1
+     moe_top_k: int = 1
+     moe_capacity_factor: int = 1
+     moe_normalize_expert_weights: Optional[Union[int, float]] = None
+     moe_loss_weight: float = 0.1
+     moe_jitter_eps: Optional[float] = None
+     moe_lbl_in_fp32: bool = False
+
+     # Parallelism arguments.
+     moe_expert_model_parallelism: bool = False
+     expert_parallel_group: Optional[dist.ProcessGroup] = None
+     pipeline_model_parallel_size: int = 1
+     num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+     # Compute arguments.
+     memory_optimized_mlp: bool = False
+     mlp_type: str = 'mlp'
+     mlp_impl: str = 'sparse'
+
+     # Initialization arguments.
+     fp16: bool = True
+     bf16: bool = False
+     device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+     init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+     output_layer_init_method: InitFn = init_method
+
+     # Benchmarking arguments.
+     uniform_expert_assignment: bool = False
+
+     # Shared expert arguments.
+     shared_expert: bool = False  # enable using shared expert
+     fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using a custom FC layer, e.g. te.Linear for FP8)
+     fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
+     remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
+     shared_expert_hidden_size: Optional[
+         int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
+     shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)
+
+     # Router Z-loss arguments.
+     moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+     moe_zloss_in_fp32: bool = False
+
+     def __post_init__(self):
+         # Sparse MLP is not supported with triton >=3.2.0
+         # TODO: Remove this once sparse is supported with triton >=3.2.0
+         if self.__getattribute__('mlp_impl') == 'sparse':
+             try:
+                 import triton
+                 if triton.__version__ >= '3.2.0':
+                     raise ValueError(
+                         'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
+                     )
+             except ImportError:
+                 raise ImportError('Triton is required for sparse MLP implementation')
+
+         if self.__getattribute__('mlp_impl') == 'grouped':
+             grouped_gemm.assert_grouped_gemm_is_available()
+
+         if self.shared_expert_hidden_size is None:
+             self.shared_expert_hidden_size = self.ffn_hidden_size
+
+
+ def from_megatron(megatron_args: Any):
+     args = Arguments()
+     for field in dataclasses.fields(args):
+         if hasattr(megatron_args, field.name):
+             setattr(args, field.name, getattr(megatron_args, field.name))
+     return args
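These are the same fields that `memory_test.py` below passes when building a dMoE; a minimal single-GPU configuration might look like the following (values are illustrative, not a recommended setup):

```python
import torch
from megablocks import Arguments, dMoE

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_impl="grouped",  # the sparse path requires triton < 3.2.0 per __post_init__
    fp16=False,
    bf16=True,
    device=torch.cuda.current_device(),
)
layer = dMoE(args).cuda()
x = torch.randn(1, 16, 1024, device="cuda", dtype=torch.bfloat16)
out, _ = layer(x)
```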
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/common.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ import torch
+
+ from .arguments import Arguments
+
+
+ def dtype(args: Arguments):
+     if args.fp16:
+         return torch.float16
+     elif args.bf16:
+         return torch.bfloat16
+     return None
+
+
+ def cast_if_autocast_enabled(tensor):
+     if torch.is_autocast_enabled():
+         if tensor.device.type == 'cuda':
+             dtype = torch.get_autocast_gpu_dtype()
+         elif tensor.device.type == 'cpu':
+             dtype = torch.get_autocast_cpu_dtype()
+         else:
+             raise NotImplementedError()
+         return tensor.to(dtype=dtype)
+     return tensor
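`cast_if_autocast_enabled` mirrors whichever dtype the surrounding autocast context selected; a small sketch of that behaviour (assumes a CUDA/HIP device):

```python
import torch
from megablocks._layers.common import cast_if_autocast_enabled

t = torch.randn(4, 4, device="cuda", dtype=torch.float32)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    print(cast_if_autocast_enabled(t).dtype)  # torch.bfloat16 inside autocast
print(cast_if_autocast_enabled(t).dtype)      # torch.float32 outside autocast
```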
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmlp_registry.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ from . import glu, mlp
7
+ from .arguments import Arguments
8
+
9
+ MlpType = Union[mlp.SparseMLP, glu.SparseGLU]
10
+
11
+ _REGISTRY = {
12
+ 'mlp': {
13
+ 'grouped': mlp.GroupedMLP,
14
+ 'sparse': mlp.SparseMLP,
15
+ },
16
+ 'glu': {
17
+ 'grouped': glu.GroupedGLU,
18
+ 'sparse': glu.SparseGLU,
19
+ },
20
+ }
21
+
22
+
23
+ def get(args: Arguments) -> MlpType:
24
+ """Returns an MLP for use in a dMoE instance.
25
+
26
+ Uses the provided arguments to instantiate the appropriate
27
+ MLP instance. This only contains MLPs for use in dMoEs
28
+ (ie. only for the dropless versions of MoEs).
29
+
30
+ Args:
31
+ args: propagated Arguments dataclass.
32
+
33
+ Returns:
34
+ An instantiated MLP constructed using the input args.
35
+ """
36
+ if args.mlp_type not in _REGISTRY:
37
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
38
+
39
+ if args.mlp_impl not in _REGISTRY[args.mlp_type]:
40
+ raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)
41
+
42
+ return _REGISTRY[args.mlp_type][args.mlp_impl](args)
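The registry keys line up with `Arguments.mlp_type` and `Arguments.mlp_impl`, so expert-MLP selection is a plain dictionary lookup. For example (illustrative, and it requires a GPU because the expert weights are allocated on `args.device`):

```python
import torch
from megablocks._layers import dmlp_registry
from megablocks._layers.arguments import Arguments

args = Arguments(
    mlp_type="glu",
    mlp_impl="grouped",
    moe_num_experts=8,
    fp16=False,
    bf16=True,
    device=torch.cuda.current_device(),
)
expert_mlp = dmlp_registry.get(args)  # an instance of glu.GroupedGLU
print(type(expert_mlp).__name__)
```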
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmoe.py ADDED
@@ -0,0 +1,337 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ # try:
8
+ # import stk.ops
9
+ # except ImportError:
10
+ # import warnings
11
+ # warnings.warn(
12
+ # 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
13
+ # )
14
+
15
+ # import megablocks.ops as ops
16
+ # # from megablocks.ops import ops
17
+ # from megablocks.layers import common, dmlp_registry, moe, mpu
18
+ # from megablocks.layers.arguments import Arguments
19
+
20
+ from .. import stk
21
+ from .. import ops
22
+ from . import common, dmlp_registry, moe, mpu
23
+ from .arguments import Arguments
24
+
25
+ def promote_scalar(x):
26
+ return x.view(1) if not len(x.size()) else x
27
+
28
+
29
+ class ParallelDroplessMLP(moe.ParallelMLP):
30
+
31
+ def __init__(self, args: Arguments):
32
+ super(ParallelDroplessMLP, self).__init__(args)
33
+ self.hidden_size = args.hidden_size
34
+ self.ffn_hidden_size = mpu.features_per_rank(args)
35
+ self.blocking = 128
36
+ self.mlp = dmlp_registry.get(args)
37
+
38
+ # Calculate the number of bits needed to represent the column indices
39
+ # in the intermediate sparse matrix.
40
+ max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
41
+ self.transpose_sort_end_bit = max(
42
+ int(np.ceil(np.log2(max_column_index))),
43
+ 1,
44
+ )
45
+
46
+ def sparse_transpose(self, size, row_indices, column_indices, offsets):
47
+ block_columns = size[1] // self.blocking
48
+
49
+ # Sort row indices by column indices to get the transposed matrix's
50
+ # column indices.
51
+ #
52
+ # NOTE: Our sort operation uses the same width indices as the input values.
53
+ # To avoid overflow when we have large activation matrices we cast to
54
+ # 32-bit before sorting.
55
+ _, gather_indices = ops.sort(
56
+ column_indices.int(),
57
+ self.transpose_sort_end_bit,
58
+ )
59
+
60
+ # There are a constant number of blocks in every row of the sparse matrix.
61
+ # A blocks offset is:
62
+ #
63
+ # row_index * blocks_per_row + column_index % blocks_per_row
64
+ #
65
+ # Once we have the block offsets ordered for transposition we can divide
66
+ # by blocks_per_row to get the transposed column indices.
67
+ column_indices_t = row_indices.gather(0, gather_indices.long())
68
+ block_offsets_t = gather_indices.int()
69
+
70
+ zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
71
+ nnz_per_column = ops.histogram(column_indices, block_columns)
72
+ nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
73
+ if nnz_per_column.dim() == 0:
74
+ # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
75
+ nnz_per_column = nnz_per_column.unsqueeze(0)
76
+ offsets_t = torch.cat([zero, nnz_per_column])
77
+ return column_indices_t, offsets_t, block_offsets_t
78
+
79
+ def topology(self, x, padded_bins):
80
+ padded_tokens, _ = x.size()
81
+ assert padded_tokens % self.blocking == 0
82
+ if self.ffn_hidden_size % self.blocking != 0:
83
+ raise ValueError(
84
+ f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
85
+ f'the block size {self.blocking}. Please update your configuration.',
86
+ )
87
+
88
+ # Offsets for the sparse matrix. All rows have the
89
+ # same number of nonzero blocks dictated by the
90
+ # dimensionality of a single expert.
91
+ block_rows = padded_tokens // self.blocking
92
+ blocks_per_row = self.ffn_hidden_size // self.blocking
93
+ offsets = torch.arange(
94
+ 0,
95
+ block_rows * blocks_per_row + 1,
96
+ blocks_per_row,
97
+ dtype=torch.int32,
98
+ device=x.device,
99
+ )
100
+
101
+ # Indices for the sparse matrix. The indices for
102
+ # the intermediate matrix are dynamic depending
103
+ # on the mapping of tokens to experts.
104
+ column_indices = ops.topology(
105
+ padded_bins,
106
+ self.blocking,
107
+ block_rows,
108
+ blocks_per_row,
109
+ )
110
+
111
+ # TODO(tgale): This is unused. Remove the need for this in stk.
112
+ # For now, use meta init to save the device memory.
113
+ data = torch.empty(
114
+ column_indices.numel(),
115
+ self.blocking,
116
+ self.blocking,
117
+ dtype=common.dtype(self.args),
118
+ device='meta',
119
+ )
120
+ shape = (
121
+ padded_tokens,
122
+ self.ffn_hidden_size * mpu.experts_per_rank(self.args),
123
+ )
124
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
125
+ column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
126
+ shape,
127
+ row_indices,
128
+ column_indices,
129
+ offsets,
130
+ )
131
+ return stk.Matrix(
132
+ shape,
133
+ data,
134
+ row_indices,
135
+ column_indices,
136
+ offsets,
137
+ column_indices_t,
138
+ offsets_t,
139
+ block_offsets_t,
140
+ )
141
+
142
+ def indices_and_padded_bins(self, top_experts):
143
+ # Sort the expert ids to produce the scatter/gather
144
+ # indices for the permutation.
145
+ top_experts = top_experts.int()
146
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)
147
+
148
+ # Histogram the expert ids to identify the number of
149
+ # tokens routed to each expert.
150
+ tokens_per_expert = ops.histogram(top_experts, self.num_experts)
151
+
152
+ # Round the token counts up to the block size used in
153
+ # the matrix muliplications. Caculate the starting
154
+ # position of each bin.
155
+ padded_tokens_per_expert = ops.round_up(
156
+ tokens_per_expert,
157
+ self.blocking,
158
+ )
159
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
160
+ padded_bins = promote_scalar(padded_bins)
161
+
162
+ # Calculate the bin bounds for the sorted tokens.
163
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
164
+ bins = promote_scalar(bins)
165
+ return indices, bin_ids, bins, padded_bins, tokens_per_expert
166
+
167
+ def sparse_forward_once(self, x, expert_weights, top_experts):
168
+ # x: [sl, bs, hs]
169
+ # expert_weights: [sl * bs, top-k]
170
+ # top_experts: [sl * bs, top-k]
171
+ expert_weights = expert_weights.flatten()
172
+ top_experts = top_experts.flatten()
173
+ with torch.no_grad():
174
+ indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))
175
+
176
+ # Route the tokens for MoE computation.
177
+ x = x.view(-1, x.shape[-1])
178
+ x = ops.padded_gather(
179
+ x,
180
+ indices,
181
+ bin_ids,
182
+ bins,
183
+ padded_bins,
184
+ self.top_k,
185
+ )
186
+
187
+ # Create the sparse matrix topology.
188
+ with torch.no_grad():
189
+ topo = self.topology(x, padded_bins)
190
+
191
+ # Perform the expert computation.
192
+ x = self.mlp(x, topo)
193
+
194
+ # Un-route the data for the MoE output.
195
+ x = ops.padded_scatter(
196
+ x,
197
+ indices,
198
+ bin_ids,
199
+ expert_weights,
200
+ bins,
201
+ padded_bins,
202
+ self.top_k,
203
+ )
204
+ return x, tokens_per_expert
205
+
206
+ # For use in the base-class parallel_forward_once.
207
+ def sparse_permute_and_compute(
208
+ self,
209
+ x,
210
+ tokens_per_expert,
211
+ indices,
212
+ bin_ids,
213
+ expert_weights,
214
+ bins,
215
+ expert_capactiy, # unused
216
+ top_k,
217
+ ):
218
+
219
+ # Round the token counts up to the block size used in the matrix
220
+ # multiplication. Calculate the starting position of each bin.
221
+ padded_tokens_per_expert = ops.round_up(
222
+ tokens_per_expert,
223
+ self.blocking,
224
+ )
225
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
226
+ padded_bins = promote_scalar(padded_bins)
227
+
228
+ # Route the tokens for MoE computation.
229
+ x = x.view(-1, x.shape[-1])
230
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
231
+
232
+ # Create the sparse matrix topology.
233
+ with torch.no_grad():
234
+ topo = self.topology(x, padded_bins)
235
+
236
+ # Perform the expert computation.
237
+ x = self.mlp(x, topo)
238
+
239
+ # Un-route the data for the MoE output.
240
+ return ops.padded_scatter(
241
+ x,
242
+ indices,
243
+ bin_ids,
244
+ expert_weights,
245
+ bins,
246
+ padded_bins,
247
+ top_k,
248
+ )
249
+
250
+ def grouped_forward_once(self, x, expert_weights, top_experts):
251
+ # x: [sl, bs, hs]
252
+ # expert_weights: [sl * bs, top-k]
253
+ # top_experts: [sl * bs, top-k]
254
+ expert_weights = expert_weights.flatten()
255
+ top_experts = top_experts.flatten()
256
+ with torch.no_grad():
257
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
258
+
259
+ out = self.grouped_permute_and_compute(
260
+ x,
261
+ tokens_per_expert,
262
+ indices,
263
+ bin_ids,
264
+ expert_weights,
265
+ bins,
266
+ -1, # unused
267
+ self.args.moe_top_k,
268
+ )
269
+ return out, tokens_per_expert
270
+
271
+ def grouped_permute_and_compute(
272
+ self,
273
+ x,
274
+ tokens_per_expert,
275
+ indices,
276
+ bin_ids,
277
+ expert_weights,
278
+ bins,
279
+ expert_capactiy, # unused
280
+ top_k,
281
+ ):
282
+
283
+ # Route the tokens for MoE computation.
284
+ x = x.view(-1, x.shape[-1])
285
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
286
+
287
+ # Perform the expert computation.
288
+ x = self.mlp(x, tokens_per_expert)
289
+
290
+ # Un-route the data for the MoE output.
291
+ return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
292
+
293
+ def forward_once(self, x, expert_weights, top_experts):
294
+ if self.args.mlp_impl == 'sparse':
295
+ return self.sparse_forward_once(x, expert_weights, top_experts)
296
+ else:
297
+ return self.grouped_forward_once(x, expert_weights, top_experts)
298
+
299
+ def permute_and_compute(
300
+ self,
301
+ x,
302
+ tokens_per_expert,
303
+ indices,
304
+ bin_ids,
305
+ expert_weights,
306
+ bins,
307
+ expert_capactiy,
308
+ top_k,
309
+ ):
310
+ if self.args.mlp_impl == 'sparse':
311
+ return self.sparse_permute_and_compute(
312
+ x,
313
+ tokens_per_expert,
314
+ indices,
315
+ bin_ids,
316
+ expert_weights,
317
+ bins,
318
+ expert_capactiy,
319
+ top_k,
320
+ )
321
+ else:
322
+ return self.grouped_permute_and_compute(
323
+ x,
324
+ tokens_per_expert,
325
+ indices,
326
+ bin_ids,
327
+ expert_weights,
328
+ bins,
329
+ expert_capactiy,
330
+ top_k,
331
+ )
332
+
333
+
334
+ class dMoE(moe.MoE):
335
+
336
+ def _init_experts_mlp(self, args: Arguments):
337
+ return ParallelDroplessMLP(args)
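`indices_and_padded_bins` above rounds each expert's token count up to the 128-wide block size before computing bin boundaries. A plain-PyTorch sketch of the same bookkeeping, for illustration only (the real code uses the fused `ops.histogram`, `ops.round_up`, and `ops.inclusive_cumsum` kernels):

```python
import torch

blocking = 128
num_experts = 4
top_experts = torch.randint(0, num_experts, (1000,), device="cuda")

tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)  # ops.histogram
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking    # ops.round_up
padded_bins = torch.cumsum(padded, 0)                                   # ops.inclusive_cumsum
bins = torch.cumsum(tokens_per_expert, 0)
print(tokens_per_expert.tolist(), padded.tolist(), padded_bins.tolist(), bins.tolist())
```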
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/gelu.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright 2024 Databricks
+ # SPDX-License-Identifier: Apache-2.0
+
+ # try:
+ #     import stk
+ # except ImportError:
+ #     import warnings
+ #     warnings.warn(
+ #         'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
+ #     )
+
+ from .. import stk
+
+ import torch
+ import torch.nn.functional as F
+
+
+ @torch.jit.script
+ def _gelu_backward_inplace(g, x):
+     tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+     ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
+     return g.mul_(ff)
+
+
+ def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
+     # NOTE: The two sparse matrices must have the same topology.
+     if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
+         return stk.Matrix(
+             x.size(),
+             _gelu_backward_inplace(grad.data, x.data),
+             x.row_indices,
+             x.column_indices,
+             x.offsets,
+             x.column_indices_t,
+             x.offsets_t,
+             x.block_offsets_t,
+         )
+     return _gelu_backward_inplace(grad, x)
+
+
+ def gelu(x: stk.Matrix):
+     assert isinstance(x, stk.Matrix)
+     return stk.Matrix(
+         x.size(),
+         F.gelu(x.data, approximate='tanh'),
+         x.row_indices,
+         x.column_indices,
+         x.offsets,
+         x.column_indices_t,
+         x.offsets_t,
+         x.block_offsets_t,
+     )
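The constants in `_gelu_backward_inplace` above come from differentiating the tanh approximation of GELU: 0.79788456 ≈ sqrt(2/π) and 0.1070322243 ≈ 3 · 0.044715 · sqrt(2/π). Written out:

```latex
u(x) = \sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)
\mathrm{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u(x)\bigr)
\frac{d}{dx}\,\mathrm{gelu}(x)
  = \tfrac{1}{2}\bigl(1 + \tanh u(x)\bigr)
  + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u(x)\bigr)\,\sqrt{2/\pi}\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr)
```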
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/glu.py ADDED
@@ -0,0 +1,244 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # import stk.ops
5
+ # try:
6
+ # import stk.ops
7
+ # except ImportError:
8
+ # import warnings
9
+ # warnings.warn(
10
+ # 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
11
+ # )
12
+
13
+ from .. import stk
14
+
15
+ import torch
16
+
17
+ # from megablocks import grouped_gemm_util as gg
18
+ # from megablocks.layers import common, mpu
19
+ # from megablocks.layers.activation_fn import act_fn
20
+ # from megablocks.layers.arguments import Arguments
21
+ # from megablocks.layers.mlp import (
22
+ # SharedMLP,
23
+ # SparseMLP,
24
+ # create_dmoe_expert_weights,
25
+ # resolve_dtensor,
26
+ # )
27
+
28
+ from .. import grouped_gemm_util as gg
29
+ from . import common, mpu
30
+ from .activation_fn import act_fn
31
+ from .arguments import Arguments
32
+ from .mlp import (
33
+ SharedMLP,
34
+ SparseMLP,
35
+ create_dmoe_expert_weights,
36
+ resolve_dtensor,
37
+ )
38
+
39
+
40
+ class SparseGLU(SparseMLP):
41
+
42
+ def __init__(self, args: Arguments):
43
+ super().__init__(args)
44
+ self.v1 = torch.nn.Parameter(
45
+ torch.empty(
46
+ self._num_rows_per_rank,
47
+ args.hidden_size,
48
+ device=args.device,
49
+ dtype=common.dtype(args),
50
+ ),
51
+ )
52
+ with torch.no_grad():
53
+ self.v1.copy_(
54
+ create_dmoe_expert_weights(
55
+ args,
56
+ args.moe_num_experts,
57
+ args.ffn_hidden_size,
58
+ args.hidden_size,
59
+ args.init_method,
60
+ ),
61
+ )
62
+
63
+ mpu.set_expert_model_parallel_attributes(
64
+ self.v1,
65
+ self._should_set_parallelism_attribute,
66
+ )
67
+
68
+ def forward(self, x, topo):
69
+ if self.args.memory_optimized_mlp:
70
+ raise NotImplementedError(
71
+ 'Memory optimized implementation not yet supported with GLU with sparse kernels.',
72
+ )
73
+
74
+ w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
75
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
76
+
77
+ # Compute the GLU.
78
+ x1 = stk.ops.sdd(x, w1.t(), topo)
79
+ x2 = stk.ops.sdd(x, v1.t(), topo)
80
+
81
+ activation_fn_out = act_fn(x1, self.args.activation_fn)
82
+ x1 = stk.ops.mul(activation_fn_out, x2)
83
+
84
+ return stk.ops.dsd(x1, w2)
85
+
86
+
87
+ class MemoryOptimizedGroupedGLU(torch.autograd.Function):
88
+ """GroupedMLP with manually scheduled memory reuse."""
89
+
90
+ @staticmethod
91
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
92
+ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
93
+ # Cast inputs using ctx dtype from AMP
94
+ if ctx._fwd_used_autocast:
95
+ x = x.to(ctx._dtype)
96
+ w1 = w1.to(ctx._dtype)
97
+ v1 = v1.to(ctx._dtype)
98
+ w2 = w2.to(ctx._dtype)
99
+ # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
100
+ if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
101
+ raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")
102
+
103
+ # Layer 0: x @ w1.t().
104
+ assert gg.backend is not None
105
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
106
+ v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)
107
+
108
+ # GeLU.
109
+ activation_fn_out = activation_fn(sdd_out) * v1_out
110
+
111
+ # Layer 1: x @ w2.
112
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
113
+
114
+ # NOTE: Save the input to the layer and the activation_fn input for
115
+ # gradient computation. We'll re-compute the activation_fn forward
116
+ # pass in the backward pass to avoid materializing another
117
+ # intermediate.
118
+ ctx.x_shape = x.shape
119
+ ctx.sdd_out_shape = sdd_out.shape
120
+ ctx.dtype = x.dtype
121
+ ctx.activation_fn = activation_fn
122
+ ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
123
+ return dsd_out
124
+
125
+ @staticmethod
126
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
127
+ def backward(ctx, ddsd_out):
128
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
129
+ raise ValueError('Expected all MLP inputs to need grad.')
130
+
131
+ # Unpack saved tensors
132
+ # dtype = ctx.dtype
133
+ saved_tensors = ctx.saved_tensors
134
+ w1, v1, w2 = saved_tensors[:3]
135
+ batch_sizes = saved_tensors[3]
136
+ x = saved_tensors[4]
137
+ sdd_out, v1_out = saved_tensors[5:7]
138
+
139
+ # Rematerialize activation_fn output.
140
+ activation_fn = ctx.activation_fn
141
+ with torch.set_grad_enabled(True):
142
+ sdd_out.requires_grad = True
143
+ v1_out.requires_grad = True
144
+ activation_fn_out = activation_fn(sdd_out) * v1_out
145
+ activation_grad_fn = activation_fn_out.backward
146
+
147
+ # Compute dw2 with recomputed activation_fn output.
148
+ assert gg.backend is not None
149
+ dw2 = gg.backend.gmm(
150
+ activation_fn_out,
151
+ ddsd_out,
152
+ batch_sizes,
153
+ trans_a=True,
154
+ )
155
+
156
+ # Compute dactivation_fn_out.
157
+ #
158
+ # NOTE: We reuse the activation_fn_out allocation.
159
+ dactivation_fn_out = activation_fn_out
160
+ gg.backend.gmm(
161
+ ddsd_out,
162
+ w2,
163
+ batch_sizes,
164
+ trans_b=True,
165
+ c=dactivation_fn_out,
166
+ )
167
+
168
+ # Compute dsdd_out.
169
+ #
170
+ # NOTE: This reuses the dactivation_fn_out allocation.
171
+ assert activation_grad_fn is not None
172
+ activation_grad_fn(dactivation_fn_out)
173
+ dsdd_out = sdd_out.grad
174
+ dv1_out = v1_out.grad
175
+
176
+ # Compute dw1.
177
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
178
+
179
+ # Compute dv1.
180
+ dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)
181
+
182
+ # Compute dx.
183
+ #
184
+ # NOTE: This reuses the ddsd_out allocation.
185
+ dx = ddsd_out
186
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
187
+ dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
188
+ return dx, dw1, dv1, dw2, None, None
189
+
190
+
191
+ memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply
192
+
193
+
194
+ class GroupedGLU(SparseGLU):
195
+
196
+ def forward(self, x, tokens_per_expert):
197
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
198
+ w1, v1, w2 = (
199
+ self.scale_grad(self.w1),
200
+ self.scale_grad(self.v1),
201
+ self.scale_grad(self.w2),
202
+ )
203
+ w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)
204
+
205
+ # Re-shape the weights for the grouped GEMMs.
206
+ ne = mpu.experts_per_rank(self.args)
207
+ w1 = w1.view(ne, -1, self.args.hidden_size)
208
+ v1 = v1.view(ne, -1, self.args.hidden_size)
209
+ w2 = w2.view(ne, -1, self.args.hidden_size)
210
+
211
+ if self.args.memory_optimized_mlp:
212
+ return memory_optimized_grouped_glu(
213
+ x,
214
+ w1,
215
+ v1,
216
+ w2,
217
+ batch_sizes,
218
+ self.args.activation_fn,
219
+ )
220
+
221
+ # Compute the MLP.
222
+ assert gg.ops is not None
223
+ x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
224
+ x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
225
+ x1 = self.args.activation_fn(x1) * x2
226
+ return gg.ops.gmm(x1, w2, batch_sizes)
227
+
228
+
229
+ class SharedGLU(SharedMLP):
230
+ """GPU for shared expert.
231
+
232
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
233
+ """
234
+
235
+ def __init__(self, args: Arguments):
236
+ super().__init__(args)
237
+ self.gate_proj = args.fc_cls(
238
+ args.hidden_size,
239
+ self.args.shared_expert_hidden_size,
240
+ **self.fc_kwargs,
241
+ )
242
+
243
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
244
+ return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
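Both the sparse and grouped GLU paths above compute the same gated MLP per expert; as a dense single-expert reference (illustrative shapes, ignoring routing and expert parallelism):

```python
import torch
import torch.nn.functional as F

hidden, ffn = 1024, 4096
x = torch.randn(32, hidden)
w1 = torch.randn(ffn, hidden)  # gate projection weights
v1 = torch.randn(ffn, hidden)  # up projection weights
w2 = torch.randn(ffn, hidden)  # down projection weights, stored as [ffn, hidden]

y = (F.gelu(x @ w1.t(), approximate="tanh") * (x @ v1.t())) @ w2
print(y.shape)  # torch.Size([32, 1024])
```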
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # from megablocks.layers import arguments, dmoe
10
+ from . import arguments, dmoe
11
+
12
+ _TESTS = ((8, 2048, 4096, 4096, 32, 4),)
13
+
14
+
15
+ def get_tensors():
16
+ ptrs = set()
17
+ out = []
18
+ for obj in gc.get_objects():
19
+ if torch.is_tensor(obj):
20
+ if not obj.is_contiguous() or obj.data_ptr() in ptrs:
21
+ continue
22
+ out.append(obj)
23
+ ptrs.add(obj.data_ptr())
24
+ return out
25
+
26
+
27
+ def test_memory(
28
+ group,
29
+ batch_size,
30
+ sequence_length,
31
+ hidden_size,
32
+ ffn_hidden_size,
33
+ num_experts,
34
+ top_k,
35
+ ):
36
+ args = arguments.Arguments(
37
+ hidden_size=hidden_size,
38
+ ffn_hidden_size=ffn_hidden_size,
39
+ moe_num_experts=num_experts,
40
+ moe_top_k=top_k,
41
+ moe_expert_model_parallelism=True,
42
+ expert_parallel_group=group,
43
+ fp16=False,
44
+ bf16=True,
45
+ device=torch.cuda.current_device(),
46
+ )
47
+ layer = dmoe.dMoE(args).cuda()
48
+
49
+ x = torch.randn((batch_size, sequence_length, hidden_size),
50
+ device=torch.cuda.current_device(),
51
+ dtype=torch.bfloat16).requires_grad_(True)
52
+ torch.cuda.empty_cache()
53
+
54
+ # Run forward + backward.
55
+ # with torch.autograd.detect_anomaly():
56
+ out, _ = layer(x)
57
+ out.mean().backward()
58
+
59
+ # Report peak memory.
60
+ mem = torch.cuda.max_memory_allocated()
61
+ print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
62
+ print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)
63
+
64
+ # Calculate weight and gradient memory usage.
65
+ weight_memory = 2 * (
66
+ layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
67
+ )
68
+
69
+ def grad_numel(x):
70
+ if x.grad is not None:
71
+ return x.grad.numel()
72
+ return 0
73
+
74
+ grad_memory = 2 * (
75
+ grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
76
+ )
77
+ weight_memory += grad_memory
78
+
79
+ print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
80
+ print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)
81
+
82
+ # Manually calculate GPU memory usage from the garbage
83
+ # collector.
84
+ gc.collect()
85
+ total = 0
86
+ tensors = get_tensors()
87
+ tensors = sorted(tensors, key=lambda x: -x.numel())
88
+ for i, t in enumerate(tensors):
89
+ total += t.numel()
90
+ print(f'{i}: {t.shape}, {t.numel() * 2}')
91
+ del tensors
92
+
93
+ print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))
94
+
95
+
96
+ if __name__ == '__main__':
97
+ assert dist.is_available()
98
+ group = dist.init_process_group(backend='nccl')
99
+ local_rank = dist.get_rank(group)
100
+ torch.cuda.set_device(local_rank)
101
+
102
+ for args in _TESTS:
103
+ test_memory(group, *args)
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.sh ADDED
@@ -0,0 +1,12 @@
+ #!/bin/bash
+
+ DISTRIBUTED_ARGUMENTS="\
+     --nproc_per_node 1 \
+     --nnodes 1 \
+     --node_rank 0 \
+     --master_addr localhost \
+     --master_port 6000"
+
+ python -m torch.distributed.launch \
+     ${DISTRIBUTED_ARGUMENTS} \
+     megablocks/layers/memory_test.py
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mlp.py ADDED
@@ -0,0 +1,587 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # try:
7
+ # import stk
8
+ # import stk.backend.triton_kernels
9
+ # import stk.ops
10
+ # except ImportError:
11
+ # import warnings
12
+ # warnings.warn(
13
+ # 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
14
+ # )
15
+
16
+ from .. import stk
17
+
18
+ import torch
19
+ from packaging import version
20
+
21
+ # from megablocks import grouped_gemm_util as gg
22
+ # from megablocks.layers import common, gelu, mpu
23
+ # from megablocks.layers.activation_fn import act_fn
24
+ # from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
25
+
26
+ from .. import grouped_gemm_util as gg
27
+ from . import common, gelu, mpu
28
+ from .activation_fn import act_fn
29
+ from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
30
+
31
+ class ScaleGradient(torch.autograd.Function):
32
+
33
+ @staticmethod
34
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
35
+ def forward(ctx: Any, x: torch.Tensor, scale: float):
36
+ ctx.scale = scale
37
+ return x
38
+
39
+ @staticmethod
40
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
41
+ def backward(ctx: torch.Tensor, grad: torch.Tensor):
42
+ return grad * ctx.scale, None
43
+
44
+
45
+ scale_gradient = ScaleGradient.apply
46
+
47
+
48
+ def resolve_dtensor(weight: torch.Tensor):
49
+ if version.parse(torch.__version__) >= version.parse('2.0.0'):
50
+ from torch.distributed._tensor import DTensor
51
+ if isinstance(weight, DTensor):
52
+ return weight.to_local()
53
+ return weight
54
+
55
+
56
+ def create_moe_expert_weights(
57
+ args: Arguments,
58
+ num_experts: int,
59
+ ffn_hidden_size: int,
60
+ hidden_size: int,
61
+ init_method: InitFn,
62
+ ):
63
+ # Create the entire weight matrix such that the sampled weights will
64
+ # not vary between data parallelism and expert model parallelism for
65
+ # the same random seed.
66
+ master_weights = torch.empty(
67
+ num_experts,
68
+ ffn_hidden_size,
69
+ hidden_size,
70
+ device=args.device,
71
+ dtype=common.dtype(args),
72
+ )
73
+ init_method(master_weights)
74
+
75
+ if not args.moe_expert_model_parallelism:
76
+ return master_weights
77
+
78
+ # Calculate the amount of sharding in each dimension.
79
+ expert_sharding_degree = mpu.expert_sharding_degree(args)
80
+ hidden_sharding_degree = mpu.hidden_sharding_degree(args)
81
+
82
+ # Calculate the experts per rank.
83
+ #
84
+ # NOTE: We assign ranks to be expert parallel before going
85
+ # tensor parallel.
86
+ rank = mpu.get_expert_parallel_rank(args)
87
+ expert_rank = rank % expert_sharding_degree
88
+ num_experts_per_rank = num_experts // expert_sharding_degree
89
+ start_expert = expert_rank * num_experts_per_rank
90
+ end_expert = (expert_rank + 1) * num_experts_per_rank
91
+
92
+ # Calculate the rows per rank.
93
+ row_rank = rank // expert_sharding_degree
94
+ num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
95
+ start_row = row_rank * num_rows_per_rank
96
+ end_row = (row_rank + 1) * num_rows_per_rank
97
+
98
+ # Slice the weight matrix to get the chunk for this rank.
99
+ with torch.no_grad():
100
+ weights = master_weights[start_expert:end_expert, start_row:end_row]
101
+ return weights
102
+
103
+
104
+ class MLP(torch.nn.Module):
105
+
106
+ def __init__(self, args: Arguments):
107
+ super().__init__()
108
+ self.args = args
109
+ # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
110
+ experts_per_rank = mpu.experts_per_rank(args)
111
+
112
+ self.w1 = torch.nn.Parameter(
113
+ torch.empty(
114
+ experts_per_rank,
115
+ args.hidden_size,
116
+ mpu.features_per_rank(args),
117
+ device=args.device,
118
+ dtype=common.dtype(args),
119
+ ),
120
+ )
121
+ self.w2 = torch.nn.Parameter(
122
+ torch.empty(
123
+ experts_per_rank,
124
+ mpu.features_per_rank(args),
125
+ args.hidden_size,
126
+ device=args.device,
127
+ dtype=common.dtype(args),
128
+ ),
129
+ )
130
+ mpu.set_expert_model_parallel_attributes(
131
+ self.w1,
132
+ args.moe_expert_model_parallelism,
133
+ )
134
+ mpu.set_expert_model_parallel_attributes(
135
+ self.w2,
136
+ args.moe_expert_model_parallelism,
137
+ )
138
+
139
+ # Initialize the parameters for the MLP.
140
+ #
141
+ # NOTE: It is important that we create the weight tensors prior
142
+ # to creating the master weights and slicing our the piece for
143
+ # this rank. If the master weights are created first the PyTorch
144
+ # caching allocator appears to use the same memory block for these
145
+ # and the slice which causes large increases in our peak memory
146
+ # usage.
147
+ with torch.no_grad():
148
+ w1 = create_moe_expert_weights(
149
+ args,
150
+ args.moe_num_experts,
151
+ args.ffn_hidden_size,
152
+ args.hidden_size,
153
+ args.init_method,
154
+ )
155
+ self.w1.copy_(w1.transpose(1, 2).contiguous())
156
+ self.w2.copy_(
157
+ create_moe_expert_weights(
158
+ args,
159
+ args.moe_num_experts,
160
+ args.ffn_hidden_size,
161
+ args.hidden_size,
162
+ args.output_layer_init_method,
163
+ ),
164
+ )
165
+
166
+ self.gradient_scale = None
167
+ if self.args.moe_expert_model_parallelism:
168
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
169
+
170
+ def scale_grad(self, w):
171
+ if self.gradient_scale is None:
172
+ return w
173
+ return scale_gradient(w, self.gradient_scale)
174
+
175
+ def forward(self, x):
176
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
177
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
178
+ x = torch.bmm(x, w1)
179
+ x = self.args.activation_fn(x)
180
+ return torch.bmm(x, w2)
181
+
182
+
183
+ def create_dmoe_expert_weights(
184
+ args: Arguments,
185
+ num_experts: int,
186
+ rows: int,
187
+ columns: int,
188
+ init_method: InitFn,
189
+ ):
190
+ weights = create_moe_expert_weights(
191
+ args,
192
+ num_experts,
193
+ rows,
194
+ columns,
195
+ init_method,
196
+ )
197
+ return weights.view([-1, columns])
198
+
199
+
200
+ class MemoryOptimizedMLP(torch.autograd.Function):
201
+ """Sparse MLP with manually scheduled memory reuse."""
202
+
203
+ @staticmethod
204
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
205
+ def forward(ctx, x, w1, w2, topo, activation_fn):
206
+ # Cast inputs using ctx dtype from AMP
207
+ if ctx._fwd_used_autocast:
208
+ x = x.to(ctx._dtype)
209
+ w1 = w1.to(ctx._dtype)
210
+ w2 = w2.to(ctx._dtype)
211
+ # x: [m, k], w1: [n, k], w2: [n, k]
212
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
213
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
214
+
215
+ topo_tensors = (
216
+ topo.row_indices,
217
+ topo.column_indices,
218
+ topo.offsets,
219
+ topo.column_indices_t,
220
+ topo.offsets_t,
221
+ topo.block_offsets_t,
222
+ )
223
+
224
+ # Layer 0: x @ w1.t().
225
+ sdd_out = stk.ops.sdd(x, w1.t(), topo)
226
+
227
+ # GeLU.
228
+ activation_fn_out = act_fn(sdd_out, activation_fn)
229
+
230
+ # Layer 1: x @ w2.
231
+ dsd_out = stk.ops.dsd(activation_fn_out, w2)
232
+
233
+ # NOTE: Save the input to the layer and the activation_fn input for
234
+ # gradient computation. We'll re-compute the activation_fn forward
235
+ # pass in the backward pass to avoid materializing another
236
+ # intermediate.
237
+ ctx.shape = topo.shape
238
+ ctx.x_shape = x.shape
239
+ ctx.sdd_out_shape = sdd_out.data.shape
240
+ ctx.dtype = x.dtype
241
+ ctx.activation_fn = activation_fn
242
+ ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
243
+ return dsd_out
244
+
245
+ @staticmethod
246
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
247
+ def backward(ctx, ddsd_out):
248
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
249
+ raise ValueError('Expected all MLP inputs to need grad.')
250
+
251
+ # unpack saved tensors
252
+ # dtype = ctx.dtype
253
+ saved_tensors = ctx.saved_tensors
254
+ w1, w2 = saved_tensors[:2]
255
+ topo_tensors = saved_tensors[2:8]
256
+ x = saved_tensors[8]
257
+ sdd_out_data = saved_tensors[9]
258
+
259
+ # rematerialize activation function output
260
+ activation_fn = ctx.activation_fn
261
+ sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
262
+ activation_fn_out, activation_grad_fn = act_fn(
263
+ sdd_out,
264
+ activation_fn,
265
+ return_grad_fn=True,
266
+ )
267
+
268
+ # Compute dw2 with recomputed activation_fn output.
269
+ dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
270
+
271
+ # Compute dactivation_fn_out.
272
+ #
273
+ # NOTE: We reuse the activation_fn_out allocation.
274
+ dactivation_fn_out = activation_fn_out
275
+ stk.backend.triton_kernels.sdd(
276
+ ddsd_out,
277
+ w2.t(),
278
+ dactivation_fn_out.shape,
279
+ dactivation_fn_out.data,
280
+ dactivation_fn_out.offsets,
281
+ dactivation_fn_out.row_indices,
282
+ dactivation_fn_out.column_indices,
283
+ )
284
+
285
+ # Compute dsdd_out.
286
+ #
287
+ # NOTE: This reuses the dactivation_fn_out allocation.
288
+ if activation_fn is DEFAULT_ACTIVATION_FN:
289
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
290
+ else:
291
+ assert activation_grad_fn is not None
292
+ activation_grad_fn(dactivation_fn_out.data)
293
+ dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
294
+
295
+ # Compute dw1.
296
+ dw1 = stk.ops.dsd(dsdd_out.t(), x)
297
+
298
+ # Compute dx.
299
+ #
300
+ # NOTE: This reuses the ddsd_out allocation.
301
+ stk.backend.triton_kernels.dsd(
302
+ dsdd_out.shape,
303
+ dsdd_out.data,
304
+ dsdd_out.offsets,
305
+ dsdd_out.row_indices,
306
+ dsdd_out.column_indices,
307
+ dsdd_out.offsets_t,
308
+ dsdd_out.column_indices_t,
309
+ dsdd_out.block_offsets_t,
310
+ False,
311
+ w1,
312
+ ddsd_out,
313
+ )
314
+ dx = ddsd_out
315
+ return dx, dw1, dw2, None, None
316
+
317
+
318
+ memory_optimized_mlp = MemoryOptimizedMLP.apply
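As a reference point for the recompute-in-backward idea used by MemoryOptimizedMLP above, here is a minimal dense sketch using torch.utils.checkpoint instead of a hand-written autograd.Function; all shapes and names are hypothetical and the sparse stk kernels are not involved.

import torch
from torch.utils.checkpoint import checkpoint

# Dense stand-in for the sparse MLP: gelu(x @ w1.t()) @ w2.
x = torch.randn(16, 32, requires_grad=True)
w1 = torch.randn(64, 32, requires_grad=True)
w2 = torch.randn(64, 32, requires_grad=True)

def dense_mlp(inp):
    # The activation output is recomputed during backward instead of being stored.
    return torch.nn.functional.gelu(inp @ w1.t()) @ w2

out = checkpoint(dense_mlp, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # torch.Size([16, 32])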
319
+
320
+
321
+ class SparseMLP(torch.nn.Module):
322
+
323
+ def __init__(self, args: Arguments):
324
+ super().__init__()
325
+ self.args = args
326
+ self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
327
+
328
+ self.w1 = torch.nn.Parameter(
329
+ torch.empty(
330
+ self._num_rows_per_rank,
331
+ args.hidden_size,
332
+ device=args.device,
333
+ dtype=common.dtype(args),
334
+ ),
335
+ )
336
+ self.w2 = torch.nn.Parameter(
337
+ torch.empty(
338
+ self._num_rows_per_rank,
339
+ args.hidden_size,
340
+ device=args.device,
341
+ dtype=common.dtype(args),
342
+ ),
343
+ )
344
+
345
+ # Initialize the parameters for the MLP.
346
+ #
347
+ # NOTE: It is important that we create the weight tensors prior
348
+ # to creating the master weights and slicing out the piece for
349
+ # this rank. If the master weights are created first the PyTorch
350
+ # caching allocator appears to use the same memory block for these
351
+ # and the slice which causes large increases in our peak memory
352
+ # usage.
353
+ with torch.no_grad():
354
+ self.w1.copy_(
355
+ create_dmoe_expert_weights(
356
+ args,
357
+ args.moe_num_experts,
358
+ args.ffn_hidden_size,
359
+ args.hidden_size,
360
+ args.init_method,
361
+ ),
362
+ )
363
+ self.w2.copy_(
364
+ create_dmoe_expert_weights(
365
+ args,
366
+ args.moe_num_experts,
367
+ args.ffn_hidden_size,
368
+ args.hidden_size,
369
+ args.output_layer_init_method,
370
+ ),
371
+ )
372
+
373
+ self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
374
+ mpu.set_expert_model_parallel_attributes(
375
+ self.w1,
376
+ self._should_set_parallelism_attribute,
377
+ )
378
+ mpu.set_expert_model_parallel_attributes(
379
+ self.w2,
380
+ self._should_set_parallelism_attribute,
381
+ )
382
+
383
+ self.gradient_scale = None
384
+ if self.args.moe_expert_model_parallelism:
385
+ self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
386
+
387
+ def scale_grad(self, w):
388
+ if self.gradient_scale is None:
389
+ return w
390
+ return scale_gradient(w, self.gradient_scale)
391
+
392
+ def forward(self, x, topo):
393
+ w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
394
+ w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
395
+ if self.args.memory_optimized_mlp:
396
+ return memory_optimized_mlp(
397
+ x,
398
+ w1,
399
+ w2,
400
+ topo,
401
+ self.args.activation_fn,
402
+ )
403
+
404
+ # Compute the MLP.
405
+ x = stk.ops.sdd(x, w1.t(), topo)
406
+ activation_fn_out = act_fn(x, self.args.activation_fn)
407
+ return stk.ops.dsd(activation_fn_out, w2)
408
+
409
+
410
+ class MemoryOptimizedGroupedMLP(torch.autograd.Function):
411
+ """GroupedMLP with manually scheduled memory reuse."""
412
+
413
+ @staticmethod
414
+ @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
415
+ def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
416
+ # Cast inputs using ctx dtype from AMP
417
+ if ctx._fwd_used_autocast:
418
+ x = x.to(ctx._dtype)
419
+ w1 = w1.to(ctx._dtype)
420
+ w2 = w2.to(ctx._dtype)
421
+ # x: [m, k], w1: [n, k], w2: [n, k]
422
+ if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
423
+ raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
424
+
425
+ # Layer 0: x @ w1.t().
426
+ assert gg.backend is not None
427
+ sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
428
+
429
+ # activation_fn
430
+ activation_fn_out = activation_fn(sdd_out)
431
+
432
+ # Layer 1: x @ w2.
433
+ dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
434
+
435
+ # NOTE: Save the input to the layer and the activation_fn input for
436
+ # gradient computation. We'll re-compute the activation_fn forward
437
+ # pass in the backward pass to avoid materializing another
438
+ # intermediate.
439
+ ctx.x_shape = x.shape
440
+ ctx.sdd_out_shape = sdd_out.shape
441
+ ctx.dtype = x.dtype
442
+ ctx.activation_fn = activation_fn
443
+ ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
444
+ return dsd_out
445
+
446
+ @staticmethod
447
+ @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
448
+ def backward(ctx: Any, ddsd_out: torch.Tensor):
449
+ if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
450
+ raise ValueError('Expected all MLP inputs to need grad.')
451
+
452
+ # Unpack saved tensors
453
+ # dtype = ctx.dtype
454
+ saved_tensors = ctx.saved_tensors
455
+ w1, w2 = saved_tensors[:2]
456
+ batch_sizes = saved_tensors[2]
457
+ x = saved_tensors[3]
458
+ sdd_out = saved_tensors[4]
459
+
460
+ # Rematerialize activation_fn output.
461
+ activation_fn = ctx.activation_fn
462
+ with torch.set_grad_enabled(True):
463
+ sdd_out.requires_grad = True
464
+ activation_fn_out = activation_fn(sdd_out)
465
+ activation_grad_fn = activation_fn_out.backward
466
+
467
+ # Compute dw2 with recomputed activation_fn output.
468
+ assert gg.backend is not None
469
+ dw2 = gg.backend.gmm(
470
+ activation_fn_out,
471
+ ddsd_out,
472
+ batch_sizes,
473
+ trans_a=True,
474
+ )
475
+
476
+ # Compute dactivation_fn_out.
477
+ #
478
+ # NOTE: We reuse the activation_fn_out allocation.
479
+ dactivation_fn_out = activation_fn_out
480
+ gg.backend.gmm(
481
+ ddsd_out,
482
+ w2,
483
+ batch_sizes,
484
+ trans_b=True,
485
+ c=dactivation_fn_out,
486
+ )
487
+
488
+ # Compute dsdd_out.
489
+ #
490
+ # NOTE: This reuses the dactivation_fn_out allocation.
491
+ if activation_fn is DEFAULT_ACTIVATION_FN:
492
+ dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
493
+ else:
494
+ assert activation_grad_fn is not None
495
+ activation_grad_fn(dactivation_fn_out)
496
+ dsdd_out = sdd_out.grad
497
+
498
+ # Compute dw1.
499
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
500
+
501
+ # Compute dx.
502
+ #
503
+ # NOTE: This reuses the ddsd_out allocation.
504
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
505
+ dx = ddsd_out
506
+ return dx, dw1, dw2, None, None
507
+
508
+
509
+ memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
510
+
511
+
512
+ class GroupedMLP(SparseMLP):
513
+
514
+ def forward(self, x, tokens_per_expert):
515
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
516
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
517
+
518
+ # Re-shape the weights for the grouped GEMMs.
519
+ ne = mpu.experts_per_rank(self.args)
520
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
521
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
522
+
523
+ if self.args.memory_optimized_mlp:
524
+ return memory_optimized_grouped_mlp(
525
+ x,
526
+ w1,
527
+ w2,
528
+ batch_sizes,
529
+ self.args.activation_fn,
530
+ )
531
+
532
+ # Compute the MLP.
533
+ assert gg.ops is not None
534
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
535
+ x = self.args.activation_fn(x)
536
+ return gg.ops.gmm(x, w2, batch_sizes)
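For readers unfamiliar with grouped GEMMs, the following plain-PyTorch sketch shows what a call like gg.ops.gmm(x, w1, batch_sizes, trans_b=True) computes; the sizes are hypothetical and the real kernel fuses the per-expert loop.

import torch

num_experts, hidden, ffn = 2, 8, 16
batch_sizes = torch.tensor([3, 5])               # tokens routed to each expert
x = torch.randn(int(batch_sizes.sum()), hidden)  # tokens grouped by expert
w1 = torch.randn(num_experts, ffn, hidden)       # one weight matrix per expert

outs, start = [], 0
for e, n in enumerate(batch_sizes.tolist()):
    # Each expert multiplies only its contiguous slice of tokens.
    outs.append(x[start:start + n] @ w1[e].t())  # trans_b=True case
    start += n
y = torch.cat(outs)
print(y.shape)  # torch.Size([8, 16])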
537
+
538
+
539
+ class SharedMLP(torch.nn.Module):
540
+ """MLP for shared expert.
541
+
542
+ Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class
543
+ """
544
+
545
+ def __init__(self, args: Arguments):
546
+ super().__init__()
547
+ self.args = args
548
+ self.fc_kwargs: dict[str, Any] = {
549
+ 'bias': args.bias,
550
+ 'device': args.device,
551
+ }
552
+ self.fc_kwargs.update(args.fc_kwargs)
553
+
554
+ self.up_proj = args.fc_cls(
555
+ args.hidden_size,
556
+ args.shared_expert_hidden_size,
557
+ **self.fc_kwargs,
558
+ )
559
+ self.act = args.activation_fn
560
+ self.down_proj = args.fc_cls(
561
+ args.shared_expert_hidden_size,
562
+ args.hidden_size,
563
+ **self.fc_kwargs,
564
+ )
565
+ self.down_proj._is_residual = True # a flag for llm-foundry init
566
+
567
+ def add_experts_sharedexpert(
568
+ self,
569
+ shared_expert_out: torch.Tensor,
570
+ expert_out: torch.Tensor,
571
+ ) -> torch.Tensor:
572
+ # Helper function to add expert output to shared expert output
573
+ # with optional weighted sum.
574
+ if self.args.shared_expert_weighted_sum:
575
+ # enable using weighted sum for shared expert output
576
+ # weighted by the number of experts used
577
+ t_experts = self.args.moe_top_k + 1
578
+ sh_mlp_out = shared_expert_out / t_experts
579
+ return sh_mlp_out.add(
580
+ expert_out,
581
+ alpha=(self.args.moe_top_k / t_experts),
582
+ )
583
+
584
+ return shared_expert_out + expert_out
585
+
586
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
587
+ return self.down_proj(self.act(self.up_proj(x)))
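The weighted-sum branch of add_experts_sharedexpert above blends the shared-expert and routed-expert outputs so the coefficients sum to one; a small sketch with hypothetical shapes and moe_top_k = 2:

import torch

top_k = 2
shared_expert_out = torch.randn(4, 8)
expert_out = torch.randn(4, 8)

# Shared output gets weight 1/(k+1), routed output gets k/(k+1).
t_experts = top_k + 1
blended = shared_expert_out / t_experts + expert_out * (top_k / t_experts)
print(blended.shape)  # torch.Size([4, 8])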
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/moe.py ADDED
@@ -0,0 +1,507 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # import megablocks.ops as ops
10
+ # from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
11
+ # from megablocks.layers.all_to_all import all_to_all
12
+ # from megablocks.layers.arguments import Arguments
13
+
14
+ from ..ops import (
15
+ sort,
16
+ histogram,
17
+ inclusive_cumsum,
18
+ exclusive_cumsum,
19
+ binned_gather,
20
+ binned_scatter,
21
+ gather,
22
+ scatter,
23
+ repeat,
24
+ replicate,
25
+ )
26
+
27
+ from . import common, mlp, mpu, router, sharedexpert_registry
28
+ from .arguments import Arguments
29
+ from .all_to_all import all_to_all
30
+
31
+ _LOAD_BALANCING_LOSS = []
32
+
33
+
34
+ def save_load_balancing_loss(loss):
35
+ global _LOAD_BALANCING_LOSS
36
+ _LOAD_BALANCING_LOSS.append(loss)
37
+
38
+
39
+ def get_load_balancing_loss():
40
+ global _LOAD_BALANCING_LOSS
41
+ return _LOAD_BALANCING_LOSS
42
+
43
+
44
+ def clear_load_balancing_loss():
45
+ global _LOAD_BALANCING_LOSS
46
+ _LOAD_BALANCING_LOSS.clear()
47
+
48
+
49
+ def batched_load_balancing_loss(args: Arguments):
50
+ if args.moe_loss_weight == 0:
51
+ return 0.0
52
+
53
+ # tokens_per_expert[i].shape = (num_experts)
54
+ # expert_scores[i].shape = (tokens, num_experts)
55
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
56
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
57
+ if args.num_layers_per_virtual_pipeline_stage is not None:
58
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
59
+
60
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
61
+ raise ValueError(
62
+ f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
63
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
64
+ f'{args.num_layers}\npipeline_model_parallel_size = '
65
+ f'{args.pipeline_model_parallel_size}\n'
66
+ 'num_layers_per_virtual_pipeline_stage'
67
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
68
+ )
69
+ if len(expert_scores) != num_layers_per_pipeline_stage:
70
+ raise ValueError(
71
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
72
+ f'but found {len(expert_scores)}.\nnum_layers = '
73
+ f'{args.num_layers}\npipeline_model_parallel_size = '
74
+ f'{args.pipeline_model_parallel_size}\n'
75
+ 'num_layers_per_virtual_pipeline_stage'
76
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
77
+ )
78
+
79
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
80
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
81
+
82
+ tokens = expert_scores[0].shape[0]
83
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
84
+
85
+ # Concatenate the contributions of each layer and convert to
86
+ # the correct types and formats for the dot product.
87
+ expert_scores = torch.cat(expert_scores, dim=1)
88
+ if args.moe_lbl_in_fp32:
89
+ expert_scores = expert_scores.float()
90
+ if tokens != 0:
91
+ expert_scores = expert_scores.mean(dim=0)
92
+ else:
93
+ expert_scores = expert_scores.sum(dim=0)
94
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
95
+
96
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
97
+ assert tokens_per_expert.numel() == expected_values
98
+ assert expert_scores.numel() == expected_values
99
+
100
+ # Calculate the total scale across all factors.
101
+ #
102
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
103
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
104
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
105
+ scale = scale_numerator / scale_denominator
106
+ return scale * torch.dot(tokens_per_expert, expert_scores)
107
+
108
+
109
+ # NOTE: This class defines MoE expert computation, including expert model parallel
110
+ # communication. When using FSDP on top of MegaBlocks this is the module that should
111
+ # be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
112
+ # parallel all2all.
113
+ class ParallelMLP(torch.nn.Module):
114
+
115
+ def __init__(self, args: Arguments):
116
+ super(ParallelMLP, self).__init__()
117
+ self.args = args
118
+
119
+ # Calculate the number of experts in total and the number of experts
120
+ # owned by this rank.
121
+ # world_size = mpu.get_expert_parallel_world_size(args)
122
+ self.num_experts = args.moe_num_experts
123
+ self.top_k = self.args.moe_top_k
124
+
125
+ # Calculate the number of bits needed to represent the expert indices
126
+ # so that we can pass it to radix sort.
127
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
128
+
129
+ # Expert MLP.
130
+ self.mlp = mlp.MLP(args)
131
+
132
+ self.bias: Optional[torch.Tensor]
133
+ if self.args.bias:
134
+ # Note that the output bias is not parallelized with expert
135
+ # model parallelism.
136
+ self.bias = torch.nn.Parameter(
137
+ torch.empty(
138
+ args.hidden_size,
139
+ device=args.device,
140
+ dtype=common.dtype(args),
141
+ ),
142
+ )
143
+ torch.nn.init.zeros_(self.bias)
144
+ else:
145
+ self.register_parameter('bias', None)
146
+
147
+ # Select the forward function for the operating mode.
148
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
149
+
150
+ def expert_capacity(self, tokens: int) -> int:
151
+ world_size = mpu.get_expert_parallel_world_size(self.args)
152
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
153
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
154
+
155
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
156
+ """Calculate the load balancing loss contribution."""
157
+ assert len(expert_scores.size()) == 2
158
+ tokens, num_experts = expert_scores.size()
159
+ assert num_experts == self.num_experts
160
+ assert len(tokens_per_expert.size()) == 1
161
+ num_experts, = tokens_per_expert.size()
162
+ assert num_experts == self.num_experts
163
+ scale = self.num_experts / (tokens * self.top_k)
164
+ return scale * torch.dot(
165
+ tokens_per_expert.to(expert_scores.dtype),
166
+ expert_scores.mean(dim=0),
167
+ )
168
+
169
+ def indices_and_bins(self,
170
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
171
+ # Sort the expert ids to produce the scatter/gather
172
+ # indices for the permutation.
173
+ #
174
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
175
+ # prior? Could we place the `torch.max` operation to return
176
+ # 32-bit expert indices?
177
+ top_expert = top_expert.int()
178
+ # output = ops.sort(top_expert, self.sort_end_bit)
179
+ output = sort(top_expert, self.sort_end_bit)
180
+ assert output is not None
181
+ bin_ids, indices = output
182
+
183
+ # Histogram the expert ids to identify the number of
184
+ # tokens routed to each expert.
185
+ #
186
+ # TODO(tgale): Does the sorted data produce a more favorable
187
+ # data distribution for histogram? Or is the op parallelism
188
+ # worth more?
189
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
190
+ tokens_per_expert = histogram(top_expert, self.num_experts)
191
+
192
+ # Calculate the bin bounds for the sorted tokens.
193
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
194
+ bins = inclusive_cumsum(tokens_per_expert, 0)
195
+ assert bins is not None
196
+ bins = bins.view(1) if not len(bins.size()) else bins
197
+
198
+ assert isinstance(indices, torch.Tensor)
199
+ assert isinstance(bin_ids, torch.Tensor)
200
+ assert isinstance(bins, torch.Tensor)
201
+ assert isinstance(tokens_per_expert, torch.Tensor)
202
+
203
+ return indices, bin_ids, bins, tokens_per_expert
204
+
205
+ def permute_and_compute(
206
+ self,
207
+ x: torch.Tensor,
208
+ tokens_per_expert: int, # unused
209
+ indices: torch.Tensor,
210
+ bin_ids: torch.Tensor, # unused
211
+ expert_weights: torch.Tensor,
212
+ bins: torch.Tensor,
213
+ expert_capacity: int,
214
+ top_k: int,
215
+ ):
216
+ # Route the tokens for MoE computation.
217
+ x = x.view(-1, x.shape[-1])
218
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
219
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
220
+ assert output is not None
221
+ x = output
222
+
223
+ # Perform the expert computation. Note that we don't
224
+ # use biases for these linear operations.
225
+ x = self.mlp(x)
226
+
227
+ # Un-route the data for the MoE output.
228
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
229
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
230
+
231
+
232
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
233
+ # x: [sl, bs, hs]
234
+ # expert_weights: [sl * bs, top-k]
235
+ # top_experts: [sl * bs, top-k]
236
+ expert_weights = expert_weights.flatten()
237
+ top_experts = top_experts.flatten()
238
+ with torch.no_grad():
239
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
240
+
241
+ # If expert_capacity is set to zero, set the number of tokens
242
+ # per expert to the maximum we need to avoid dropping tokens.
243
+ sl, bs, _ = x.size()
244
+ expert_capacity = self.expert_capacity(sl * bs)
245
+ if expert_capacity == 0:
246
+ expert_capacity = torch.max(tokens_per_expert).item()
247
+
248
+ x = self.permute_and_compute(
249
+ x,
250
+ tokens_per_expert,
251
+ indices,
252
+ bin_ids,
253
+ expert_weights,
254
+ bins,
255
+ expert_capacity,
256
+ self.top_k,
257
+ )
258
+ return x, tokens_per_expert
259
+
260
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
261
+ # NOTE: This function implements the same computation as forward_once
262
+ # but with expert model parallelism.
263
+ #
264
+ # 1. Permute the tokens locally so that they are grouped by their
265
+ # expert assignments. This allows us to transfer all of the tokens
266
+ # for a remote device in one communication primitive.
267
+ #
268
+ # 2. Permute the tokens across the expert parallel devices. After
269
+ # this is completed each device has all of the tokens assigned to
270
+ # its set of experts in its local HBM.
271
+ #
272
+ # 3. Permute the tokens locally so that they are grouped by their
273
+ # expert assignment. After the distributed permutation the tokens
274
+ # are grouped by which device they came from. We re-order them
275
+ # locally to allow for efficient computation.
276
+ #
277
+ # After this series of permutations we compute the linear layers
278
+ # and then repeat these three steps in reverse to produce the final
279
+ # output.
280
+ #
281
+ # Compute the mapping of local tokens to experts.
282
+ expert_weights = expert_weights.flatten()
283
+ top_experts = top_experts.flatten()
284
+ with torch.no_grad():
285
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
286
+
287
+ # If we're sharding the experts along the hidden dimension,
288
+ # multiple devices own parts of the same sets of experts.
289
+ # Replicate the token counts so every device gets the counts.
290
+ # repeated_tokens_per_expert = ops.repeat(
291
+ repeated_tokens_per_expert = repeat(
292
+ tokens_per_expert,
293
+ (mpu.hidden_sharding_degree(self.args),),
294
+ )
295
+
296
+ # Pass token count information to the device on which the
297
+ # target expert resides.
298
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
299
+ tpe_handle = dist.all_to_all_single(
300
+ parallel_tokens_per_expert,
301
+ repeated_tokens_per_expert,
302
+ group=self.args.expert_parallel_group,
303
+ async_op=True,
304
+ )
305
+
306
+ # Permute locally and without any padding so that tokens for each
307
+ # parallel device are stored contiguously.
308
+ #
309
+ # This view updates the shape of the tensor from [sl, bs, hs] to
310
+ # [sl * bs, hs] prior to the permutation.
311
+ x = x.view(-1, x.shape[-1])
312
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
313
+ output = gather(x, indices, bin_ids, bins, self.top_k)
314
+ assert output is not None
315
+ x = output
316
+
317
+ # Compute the number of tokens that will be received from each
318
+ # device and permute the input data across the devices.
319
+ with torch.no_grad():
320
+ tpe_handle.wait()
321
+ experts_per_rank = mpu.experts_per_rank(self.args)
322
+
323
+ # Reshape to [world_size, num_experts_per_rank].
324
+ world_size = mpu.get_expert_parallel_world_size(self.args)
325
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
326
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
327
+
328
+ # TODO(tgale): It might be faster to do this on the GPU and
329
+ # then communicate the results back to the host.
330
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
331
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
332
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
333
+
334
+ # Convert the send/recv counts to lists.
335
+ send_counts = send_counts.tolist()
336
+ recv_counts = recv_counts.tolist()
337
+ tokens_received = sum(recv_counts)
338
+
339
+ # If we're sharding the experts along the hidden dimension,
340
+ # multiple devices own parts of the same sets of experts.
341
+ # Replicate the token counts so devices that share experts
342
+ # get all of the tokens assigned to them.
343
+ #
344
+ # TODO(tgale): Fuse this into the prior, local permutation.
345
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
346
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
347
+
348
+ # Start the cross-device permutation asynchronously so we can
349
+ # overlap communication with computation.
350
+ parallel_x, parallel_x_handle = all_to_all(
351
+ x,
352
+ recv_counts,
353
+ send_counts,
354
+ self.args.expert_parallel_group,
355
+ async_op=True,
356
+ )
357
+
358
+ with torch.no_grad():
359
+ # After we do the cross-device permutation we have the tokens on the
360
+ # correct device but not yet grouped by expert because we received
361
+ # tokens from each device as contiguous chunks. To group the tokens
362
+ # for expert computation we'll do one more local permutation. The
363
+ # rest of this torch.no_grad() scope sets up the indices and bins
364
+ # for this permutation.
365
+ # replicate_bins = ops.inclusive_cumsum(
366
+ replicate_bins = inclusive_cumsum(
367
+ parallel_tokens_per_expert.flatten(),
368
+ 0,
369
+ )
370
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
371
+
372
+ # Construct the expert indices for the permuted tokens.
373
+ parallel_top_expert = torch.remainder(
374
+ torch.arange(
375
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
376
+ dtype=torch.int32,
377
+ device=indices.device,
378
+ ),
379
+ mpu.experts_per_rank(self.args),
380
+ )
381
+ # parallel_top_expert = ops.replicate(
382
+ parallel_top_expert = replicate(
383
+ parallel_top_expert.unsqueeze(dim=0),
384
+ replicate_bins,
385
+ tokens_received,
386
+ ).flatten()
387
+
388
+ # TODO(tgale): The sort_end_bit here can be reduced.
389
+ # parallel_bin_ids, parallel_indices = ops.sort(
390
+ parallel_bin_ids, parallel_indices = sort(
391
+ parallel_top_expert,
392
+ self.sort_end_bit,
393
+ )
394
+
395
+ # Calculate the bins boundaries from the token counts.
396
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
397
+ dim=0,
398
+ dtype=torch.int,
399
+ )
400
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
401
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
402
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
403
+
404
+ # If expert_capacity is set to zero, set the number of tokens
405
+ # per expert to the maximum we need to avoid dropping tokens.
406
+ tokens, _ = x.size()
407
+ expert_capacity = self.expert_capacity(tokens)
408
+ if expert_capacity == 0:
409
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
410
+
411
+ # Locally permute the tokens and perform the expert computation.
412
+ # Block to make sure that the cross-device permutation is complete.
413
+ if self.args.mlp_impl == 'grouped':
414
+ # GroupedMLP requires counts on CPU. We can use the tensor already
415
+ # moved to CPU for the prior all_to_all, which avoids an extra
416
+ # device synchronization.
417
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
418
+ dim=0,
419
+ dtype=torch.int,
420
+ )
421
+ parallel_x_handle.wait()
422
+ parallel_x = self.permute_and_compute(
423
+ parallel_x,
424
+ parallel_tokens_per_expert,
425
+ parallel_indices,
426
+ parallel_bin_ids,
427
+ None, # expert_weights
428
+ parallel_bins,
429
+ expert_capacity,
430
+ top_k=1,
431
+ )
432
+
433
+ # Un-permute the tokens across the devices.
434
+ x, _ = all_to_all(
435
+ parallel_x,
436
+ send_counts,
437
+ recv_counts,
438
+ self.args.expert_parallel_group,
439
+ )
440
+
441
+ # Reduce along the hidden sharding to get the final outputs.
442
+ #
443
+ # TODO(tgale): Fuse this into the following local permutation.
444
+ shape = (
445
+ mpu.hidden_sharding_degree(self.args),
446
+ -1,
447
+ self.args.hidden_size,
448
+ )
449
+ # x = ops.sum(x.view(shape), dim=0)
450
+ x = x.view(shape).sum(dim=0)
451
+
452
+ # Un-permute locally to setup for the next series of operations.
453
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
454
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
455
+ return x, tokens_per_expert.flatten()
456
+
457
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
458
+ in_shape = x.size()
459
+
460
+ # Compute the experts.
461
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
462
+ if self.training and self.args.moe_loss_weight > 0:
463
+ save_load_balancing_loss((tokens_per_expert, scores))
464
+ x = x.view(in_shape)
465
+ if self.bias is not None:
466
+ if self.args.return_bias:
467
+ return x, self.bias
468
+ return x + self.bias
469
+ return x
470
+
471
+
472
+ class MoE(torch.nn.Module):
473
+
474
+ def __init__(self, args: Arguments):
475
+ super(MoE, self).__init__()
476
+
477
+ # Token router.
478
+ self.router = router.LearnedRouter(args)
479
+
480
+ # Expert computation helper.
481
+ self.experts = self._init_experts_mlp(args)
482
+
483
+ self.shared_expert = None
484
+ if args.shared_expert:
485
+ # SharedExpert computation helper.
486
+ self.shared_expert = sharedexpert_registry.get(args)
487
+
488
+ def _init_experts_mlp(self, args: Arguments):
489
+ return ParallelMLP(args)
490
+
491
+ def forward(self, x: torch.Tensor):
492
+ # NOTE: If we're going to cast the activations to lower precision
493
+ # do it before we permute the tokens to save bandwidth.
494
+ x = common.cast_if_autocast_enabled(x)
495
+
496
+ # Compute the expert scores and assignments.
497
+ scores, expert_weights, top_experts = self.router(x)
498
+
499
+ # Compute the experts.
500
+ out = self.experts(x, scores, expert_weights, top_experts)
501
+ if self.shared_expert is not None:
502
+ shared_expert_out = self.shared_expert(x)
503
+ out = self.shared_expert.add_experts_sharedexpert(
504
+ shared_expert_out,
505
+ out,
506
+ )
507
+ return out
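To make the load-balancing loss above concrete, here is a single-layer sketch of load_balancing_loss with hypothetical sizes; the scores come from a softmax router and tokens_per_expert from the routing histogram.

import torch

num_experts, tokens, top_k = 4, 16, 2
expert_scores = torch.softmax(torch.randn(tokens, num_experts), dim=-1)
top_experts = torch.randint(0, num_experts, (tokens * top_k,))
tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)

# scale * dot(tokens_per_expert, mean expert score), as in load_balancing_loss.
scale = num_experts / (tokens * top_k)
loss = scale * torch.dot(tokens_per_expert.float(), expert_scores.mean(dim=0))
print(loss)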
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mpu.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # from megablocks.layers.arguments import Arguments
10
+ from .arguments import Arguments
11
+
12
+
13
+ class MoeParam(torch.Tensor):
14
+
15
+ def __init__(self):
16
+ super().__init__(self)
17
+ self.expert_model_parallel: bool
18
+
19
+
20
+ def is_moe_param(tensor: torch.Tensor) -> bool:
21
+ return hasattr(tensor, 'expert_model_parallel')
22
+
23
+
24
+ def get_expert_parallel_world_size(args: Arguments) -> int:
25
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
26
+
27
+
28
+ def get_expert_parallel_rank(args: Arguments) -> int:
29
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
30
+
31
+
32
+ def set_expert_model_parallel_attributes(
33
+ tensor: torch.Tensor,
34
+ is_parallel: bool,
35
+ ):
36
+ assert not hasattr(tensor, 'expert_model_parallel')
37
+ setattr(tensor, 'expert_model_parallel', is_parallel)
38
+
39
+
40
+ def param_is_expert_model_parallel(param: MoeParam) -> bool:
41
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
42
+
43
+
44
+ def copy_expert_model_parallel_attributes(
45
+ destination_tensor: torch.Tensor,
46
+ source_tensor: torch.Tensor,
47
+ ):
48
+ if hasattr(source_tensor, 'expert_model_parallel'):
49
+ setattr(
50
+ destination_tensor,
51
+ 'expert_model_parallel',
52
+ getattr(source_tensor, 'expert_model_parallel'),
53
+ )
54
+
55
+
56
+ def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
57
+ world_size = dist.get_world_size(group)
58
+ rank = dist.get_rank(group)
59
+ for i in range(world_size):
60
+ dist.barrier(group)
61
+ if i == rank:
62
+ print(f'rank = {rank}', *x)
63
+
64
+
65
+ # Helpers for expert/tensor sharding.
66
+ def expert_sharding_degree(args: Arguments) -> int:
67
+ world_size = get_expert_parallel_world_size(args)
68
+ esd = min(world_size, args.moe_num_experts)
69
+
70
+ if (args.moe_num_experts % esd) != 0:
71
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
72
+ return esd
73
+
74
+
75
+ def hidden_sharding_degree(args: Arguments) -> int:
76
+ world_size = get_expert_parallel_world_size(args)
77
+ esd = expert_sharding_degree(args)
78
+ hsd = world_size // esd
79
+
80
+ if (args.ffn_hidden_size % hsd) != 0:
81
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
82
+ if (esd * hsd) != world_size:
83
+ raise ValueError(
84
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
85
+ )
86
+ return hsd
87
+
88
+
89
+ def experts_per_rank(args: Arguments) -> int:
90
+ return args.moe_num_experts // expert_sharding_degree(args)
91
+
92
+
93
+ def features_per_rank(args: Arguments) -> int:
94
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
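A small worked example of the sharding arithmetic above, with hypothetical sizes: eight expert-parallel ranks and four experts give an expert sharding degree of four and a hidden sharding degree of two.

world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096

esd = min(world_size, moe_num_experts)      # expert_sharding_degree -> 4
hsd = world_size // esd                     # hidden_sharding_degree -> 2
assert moe_num_experts % esd == 0
assert ffn_hidden_size % hsd == 0 and esd * hsd == world_size

experts_per_rank = moe_num_experts // esd   # 1 expert per rank
features_per_rank = ffn_hidden_size // hsd  # 2048 features per rank
print(experts_per_rank, features_per_rank)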
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/router.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ # from megablocks.layers import common
8
+ # from megablocks.layers.arguments import Arguments
9
+ from . import common
10
+ from .arguments import Arguments
11
+
12
+ _ROUTER_LOGITS = []
13
+
14
+
15
+ def _save_router_logits(logits: torch.Tensor, args: Arguments):
16
+ if args.moe_zloss_weight == 0:
17
+ return
18
+ global _ROUTER_LOGITS
19
+ _ROUTER_LOGITS.append(logits)
20
+
21
+
22
+ def clear_router_zloss():
23
+ global _ROUTER_LOGITS
24
+ _ROUTER_LOGITS.clear()
25
+
26
+
27
+ def batched_router_zloss(args: Arguments):
28
+ global _ROUTER_LOGITS
29
+
30
+ if args.moe_zloss_weight == 0:
31
+ import warnings
32
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
33
+ return 0
34
+
35
+ logits_per_router = _ROUTER_LOGITS
36
+
37
+ if args.moe_zloss_in_fp32:
38
+ logits_per_router = [logits.float() for logits in logits_per_router]
39
+
40
+ unscaled_zloss_per_router = torch.stack([
41
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
42
+ ])
43
+
44
+ return args.moe_zloss_weight * unscaled_zloss_per_router
45
+
46
+
47
+ # NOTE: To enable end-to-end benchmarking without convergence we
48
+ # support a flag to force the router to assign tokens uniformly
49
+ # across the experts. We do this with a custom autograd operation
50
+ # so that PyTorch still executes the full set of router operations.
51
+ class _UniformExpertAssignment(torch.autograd.Function):
52
+
53
+ @staticmethod
54
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
55
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
56
+ out = torch.remainder(out, num_experts)
57
+ return out.view(x.shape)
58
+
59
+
60
+ _uniform_expert_assignment = _UniformExpertAssignment.apply
61
+
62
+
63
+ class LearnedRouter(torch.nn.Module):
64
+
65
+ def __init__(self, args: Arguments):
66
+ super().__init__()
67
+ self.args = args
68
+
69
+ # Learned router parameters.
70
+ #
71
+ # NOTE: This weight matrix is not parallelized with expert model
72
+ # parallelism. Each device needs the entire router weight matrix
73
+ # so that it can route its batch of data correctly.
74
+ self.layer = torch.nn.Linear(
75
+ args.hidden_size,
76
+ args.moe_num_experts,
77
+ bias=False,
78
+ dtype=common.dtype(args),
79
+ device=args.device,
80
+ )
81
+ args.init_method(self.layer.weight)
82
+
83
+ def jitter(self, x: torch.Tensor):
84
+ low: float = 1.0 - self.args.moe_jitter_eps
85
+ high: float = 1.0 + self.args.moe_jitter_eps
86
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
87
+ return low + noise * (high - low)
88
+
89
+ def _top_k(self, scores: torch.Tensor):
90
+ if self.args.moe_top_k == 1:
91
+ return scores.max(dim=-1, keepdim=True)
92
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
93
+
94
+ def forward(self, x: torch.Tensor):
95
+ if self.training and self.args.moe_jitter_eps is not None:
96
+ x = x * self.jitter(x)
97
+
98
+ logits = self.layer(x.view(-1, x.shape[-1]))
99
+ _save_router_logits(logits, self.args)
100
+ scores = logits.softmax(dim=-1)
101
+ expert_weights, expert_indices = self._top_k(scores)
102
+ if self.args.moe_normalize_expert_weights:
103
+ expert_weights = expert_weights / torch.norm(
104
+ expert_weights,
105
+ p=self.args.moe_normalize_expert_weights,
106
+ dim=-1,
107
+ keepdim=True,
108
+ )
109
+
110
+ expert_indices = (
111
+ _uniform_expert_assignment(
112
+ expert_indices,
113
+ self.args.moe_num_experts,
114
+ ) if self.args.uniform_expert_assignment else expert_indices
115
+ )
116
+ return scores, expert_weights, expert_indices
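A minimal sketch of the routing math performed by LearnedRouter.forward, with hypothetical sizes (hidden_size = 8, moe_num_experts = 4, moe_top_k = 2) and a random router weight standing in for self.layer:

import torch

x = torch.randn(16, 8)                        # flattened tokens
router_weight = torch.randn(4, 8)             # stands in for self.layer.weight

logits = x @ router_weight.t()
scores = logits.softmax(dim=-1)
expert_weights, expert_indices = torch.topk(scores, k=2, dim=-1)

# Optional re-normalization, as when moe_normalize_expert_weights is set.
expert_weights = expert_weights / expert_weights.norm(p=1, dim=-1, keepdim=True)
print(expert_weights.shape, expert_indices.shape)  # both torch.Size([16, 2])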
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/sharedexpert_registry.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ # from megablocks.layers import glu, mlp
7
+ # from megablocks.layers.arguments import Arguments
8
+ from . import glu, mlp
9
+ from .arguments import Arguments
10
+
11
+ _REGISTRY = {
12
+ 'mlp': mlp.SharedMLP,
13
+ 'glu': glu.SharedGLU,
14
+ }
15
+
16
+
17
+ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
18
+ """Returns an SharedMLP for use in a dMoE instance.
19
+
20
+ Uses the provided arguments to instantiate the appropriate
21
+ SharedMLP instance.
22
+
23
+ Args:
24
+ args: propagated Arguments dataclass.
25
+
26
+ Returns:
27
+ An instantiated SharedMLP constructed using the input args.
28
+ """
29
+ if args.mlp_type not in _REGISTRY:
30
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
31
+
32
+ return _REGISTRY[args.mlp_type](args)
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_megablocks_rocm.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee1181fa6e502b6f3fa75ce439a68f6dccb691af130934c3dd697f3efa5cb723
3
+ size 6437768
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_ops.py ADDED
@@ -0,0 +1,18 @@
1
+
2
+ import torch
3
+ from pathlib import Path
4
+
5
+ _LIB_NAME = "_megablocks_rocm.so"
6
+
7
+
8
+ def _load_ops():
9
+ lib_path = Path(__file__).with_name(_LIB_NAME)
10
+ torch.ops.load_library(str(lib_path))
11
+ return torch.ops._megablocks_rocm
12
+
13
+
14
+ ops = _load_ops()
15
+
16
+
17
+ def add_op_namespace_prefix(op_name: str) -> str:
18
+ return f"_megablocks_rocm::{op_name}"
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_version.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """The MegaBlocks Version."""
5
+
6
+ __version__ = '0.11.0.dev0'
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/kernels.py ADDED
@@ -0,0 +1,557 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ # Stub triton autotune when testing in an env that does not have CUDA;
9
+ # this approach preserves the original code but enables testing without a GPU
10
+ if not torch.cuda.is_available():
11
+ import warnings
12
+
13
+ warnings.warn("CUDA is not available. Triton autotuning is disabled.")
14
+
15
+ def _no_autotune(*args, **kwargs):
16
+ def deco(fn):
17
+ return fn
18
+ return deco
19
+
20
+ triton.autotune = _no_autotune
21
+
22
+
23
+ def assert_is_tensor(x, ndim):
24
+ if x.ndim != ndim:
25
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
26
+
27
+
28
+ def assert_is_matrix(x):
29
+ assert_is_tensor(x, 2)
30
+
31
+
32
+ def assert_is_vector(x):
33
+ if x.ndim != 1:
34
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
35
+
36
+
37
+ def assert_equal(a, b):
38
+ if a != b:
39
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
40
+
41
+
42
+ # a: (tokens, hidden_size), real.
43
+ # indices: (tokens * top_k), integer.
44
+ # bin_ids: (tokens * top_k), integer.
45
+ # weights: (tokens * top_k), real.
46
+ # bins: (num_experts), integer.
47
+ # padded_bins: (num_experts), integer.
48
+ @triton.autotune(
49
+ configs=[
50
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
51
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
52
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
53
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
54
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
55
+ ],
56
+ key=['NUM_COLUMNS'],
57
+ )
58
+ @triton.jit
59
+ def _padded_copy(
60
+ a,
61
+ b,
62
+ indices,
63
+ bin_ids,
64
+ weights,
65
+ bins,
66
+ padded_bins,
67
+ NUM_COLUMNS: tl.constexpr,
68
+ TOP_K: tl.constexpr,
69
+ BLOCK_X: tl.constexpr,
70
+ A_TO_B: tl.constexpr,
71
+ SCALE: tl.constexpr,
72
+ ):
73
+ # Our index into array 'a'.
74
+ index_a = tl.load(indices + tl.program_id(0))
75
+
76
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
77
+ # number of rows since they could be padded.
78
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
79
+
80
+ # Now we know what bin we're assigned to, but we need to know how
81
+ # many threadblocks were assigned to earlier bins so we can offset
82
+ # in our bin properly.
83
+ offset_in_bin = tl.program_id(0)
84
+ if bin_idx > 0:
85
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
86
+
87
+ # Load the starting index of our bin in array 'b'.
88
+ index_b = offset_in_bin
89
+ if bin_idx > 0:
90
+ index_b += tl.load(padded_bins + bin_idx - 1)
91
+
92
+ # Offset the input and output pointers.
93
+ #
94
+ # If we're going from A to B, divide the input index to copy
95
+ # the same input repeatedly. If we're going from B to A we
96
+ # need to reduce the result. Using atomics is slow, so we
97
+ # do the reduce step in a second kernel.
98
+ offset = index_a // TOP_K if A_TO_B else index_a
99
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
100
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
101
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
102
+
103
+ # Load the scale, if requested.
104
+ scale = tl.load(weights + index_a) if SCALE else 1
105
+
106
+ # Swap the pointers depending on the direction.
107
+ iptr = a if A_TO_B else b
108
+ optr = b if A_TO_B else a
109
+
110
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
111
+ for _ in range(iterations):
112
+ mask = offsets < NUM_COLUMNS
113
+ x = tl.load(iptr + offsets, mask=mask)
114
+ x = x.to(tl.float32) * scale.to(tl.float32)
115
+
116
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
117
+
118
+ offsets += BLOCK_X
119
+
120
+
121
+ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
122
+ # Validate the input shapes.
123
+ assert_is_matrix(x)
124
+ assert_is_vector(indices)
125
+ assert_is_vector(bin_ids)
126
+ assert_is_vector(bins)
127
+ assert_is_vector(padded_bins)
128
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
129
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
130
+ assert_equal(bins.size(), padded_bins.size())
131
+
132
+ if weights is not None:
133
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
134
+
135
+ # NOTE: Because of the padding, the output size is dynamic.
136
+ # We load the final padded bin bound to get the output rows.
137
+ output_rows = padded_bins[-1].cpu().item()
138
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
139
+ _padded_copy[(indices.shape[0],)](
140
+ x,
141
+ out,
142
+ indices,
143
+ bin_ids,
144
+ weights,
145
+ bins,
146
+ padded_bins,
147
+ NUM_COLUMNS=x.shape[1],
148
+ A_TO_B=True,
149
+ TOP_K=top_k,
150
+ SCALE=weights is not None,
151
+ )
152
+ return out
153
+
154
+
155
+ def gather(x, indices, bin_ids, weights, bins, top_k):
156
+ # Validate the input shapes.
157
+ assert_is_matrix(x)
158
+ assert_is_vector(indices)
159
+ assert_is_vector(bin_ids)
160
+ assert_is_vector(bins)
161
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
162
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
163
+
164
+ if weights is not None:
165
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
166
+
167
+ # NOTE: There is no padding so the output rows equals the
168
+ # input rows multiplied by top_k.
169
+ output_rows = x.shape[0] * top_k
170
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
171
+ _padded_copy[(indices.shape[0],)](
172
+ x,
173
+ out,
174
+ indices,
175
+ bin_ids,
176
+ weights,
177
+ bins,
178
+ bins,
179
+ NUM_COLUMNS=x.shape[1],
180
+ A_TO_B=True,
181
+ TOP_K=top_k,
182
+ SCALE=weights is not None,
183
+ )
184
+ return out
185
+
186
+
187
+ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
188
+ # Validate the input shapes.
189
+ assert_is_matrix(x)
190
+ assert_is_vector(indices)
191
+ assert_is_vector(bin_ids)
192
+ assert_is_vector(bins)
193
+ assert_is_vector(padded_bins)
194
+ assert_equal(indices.shape[0], bin_ids.shape[0])
195
+ assert_equal(bins.size(), padded_bins.size())
196
+
197
+ if weights is not None:
198
+ assert_equal(indices.shape[0], weights.shape[0])
199
+
200
+ tokens = indices.shape[0] // top_k
201
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
202
+ _padded_copy[(indices.shape[0],)](
203
+ out,
204
+ x,
205
+ indices,
206
+ bin_ids,
207
+ weights,
208
+ bins,
209
+ padded_bins,
210
+ NUM_COLUMNS=x.shape[1],
211
+ A_TO_B=False,
212
+ TOP_K=top_k,
213
+ SCALE=weights is not None,
214
+ )
215
+
216
+ # Reduce along the top-k dimension, if needed.
217
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
218
+
219
+
220
+ def scatter(x, indices, bin_ids, weights, bins, top_k):
221
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
222
+
223
+
224
+ # x: (tokens, top_k, hidden_size), real
225
+ # grad: (tokens, hidden_size), real.
226
+ # wgrad: (tokens, top_k), real.
227
+ # indices: (tokens * top_k), integer.
228
+ # bin_ids: (tokens * top_k), integer.
229
+ # bins: (num_experts), integer.
230
+ # padded_bins: (num_experts), integer.
231
+ @triton.autotune(
232
+ configs=[
233
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
234
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
235
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
236
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
237
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
238
+ ],
239
+ key=['NUM_COLUMNS'],
240
+ )
241
+ @triton.jit
242
+ def _padded_copy_wgrad(
243
+ x,
244
+ grad,
245
+ wgrad,
246
+ indices,
247
+ bin_ids,
248
+ bins,
249
+ padded_bins,
250
+ NUM_COLUMNS: tl.constexpr,
251
+ TOP_K: tl.constexpr,
252
+ BLOCK_X: tl.constexpr,
253
+ ):
254
+ # Our index into 'tokens * top_k'.
255
+ index_out = tl.load(indices + tl.program_id(0))
256
+
257
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
258
+ # number of rows since they could be padded.
259
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
260
+
261
+ # Now we know what bin we're assigned to, but we need to know how
262
+ # many threadblocks were assigned to earlier bins so we can offset
263
+ # in our bin properly.
264
+ offset_in_bin = tl.program_id(0)
265
+ if bin_idx > 0:
266
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
267
+
268
+ # Load the starting index of our bin in array 'x'.
269
+ index_x = offset_in_bin
270
+ if bin_idx > 0:
271
+ index_x += tl.load(padded_bins + bin_idx - 1)
272
+
273
+ # Offset the input and output pointers.
274
+ wgrad += index_out
275
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
276
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
277
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
278
+
279
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
280
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
281
+ for _ in range(iterations):
282
+ mask = offsets < NUM_COLUMNS
283
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
284
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
285
+ acc += data * scale
286
+ offsets += BLOCK_X
287
+
288
+ # Reduce to get the final result and store.
289
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
290
+ tl.store(wgrad, out)
291
+
292
+
293
+ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
294
+ # Validate the input shapes.
295
+ assert_is_matrix(x)
296
+ assert_is_matrix(grad)
297
+ assert_is_vector(indices)
298
+ assert_is_vector(bin_ids)
299
+ assert_is_vector(bins)
300
+ assert_is_vector(padded_bins)
301
+ assert_equal(indices.shape[0], bin_ids.shape[0])
302
+ assert_equal(bins.size(), padded_bins.size())
303
+
304
+ tokens = indices.shape[0] // top_k
305
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
306
+ _padded_copy_wgrad[(indices.shape[0],)](
307
+ x,
308
+ grad,
309
+ out,
310
+ indices,
311
+ bin_ids,
312
+ bins,
313
+ padded_bins,
314
+ NUM_COLUMNS=x.shape[1],
315
+ TOP_K=top_k,
316
+ )
317
+ return out
318
+
319
+
320
+ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
321
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
322
+
323
+
324
+ # a: (tokens, hidden_size), real.
325
+ # b: (num_experts, expert_capacity, num_columns), real.
326
+ # indices: (tokens * top_k), integer.
327
+ # weights: (tokens * top_k), real.
328
+ # bins: (num_experts), integer.
329
+ @triton.autotune(
330
+ configs=[
331
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
332
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
333
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
334
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
335
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
336
+ ],
337
+ key=['NUM_COLUMNS'],
338
+ )
339
+ @triton.jit
340
+ def _binned_copy(
341
+ a,
342
+ b,
343
+ num_experts,
344
+ expert_capacity,
345
+ indices,
346
+ weights,
347
+ bins,
348
+ NUM_COLUMNS: tl.constexpr,
349
+ TOP_K: tl.constexpr,
350
+ BLOCK_X: tl.constexpr,
351
+ A_TO_B: tl.constexpr,
352
+ SCALE: tl.constexpr,
353
+ ):
354
+ # Load our indices into the output.
355
+ expert_idx = tl.program_id(0)
356
+ entry_idx = tl.program_id(1)
357
+
358
+ # Calculate our offset into the output.
359
+ index_b = expert_idx * expert_capacity + entry_idx
360
+
361
+ # Load the index bounds for our bin and calculate
362
+ # the number of tokens assigned to our expert.
363
+ start = 0
364
+ if expert_idx > 0:
365
+ start = tl.load(bins + expert_idx - 1)
366
+ end = tl.load(bins + expert_idx)
367
+ num_tokens = end - start
368
+
369
+ # Calculate our offset into the input. If we don't
370
+ # have an input exit early.
371
+ if entry_idx >= num_tokens:
372
+ return
373
+ index_a = tl.load(indices + start + entry_idx)
374
+
375
+ # Offset the input and output pointers.
376
+ #
377
+ # If we're going from A to B, divide the input index to copy
378
+ # the same input repeatedly. If we're going from B to A we
379
+ # need to reduce the result. Using atomics is slow, so we
380
+ # do the reduce step in a second kernel.
381
+ offset = index_a // TOP_K if A_TO_B else index_a
382
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
383
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
384
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
385
+
386
+ # Load the scale, if requested.
387
+ scale = tl.load(weights + index_a) if SCALE else 1
388
+
389
+ # Swap the pointers depending on the direction.
390
+ #
391
+ # NOTE: We need to zero the output in both directions.
392
+ iptr = a if A_TO_B else b
393
+ optr = b if A_TO_B else a
394
+
395
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
396
+ for _ in range(iterations):
397
+ mask = offsets < NUM_COLUMNS
398
+ x = tl.load(iptr + offsets, mask=mask)
399
+ x = x.to(tl.float32) * scale.to(tl.float32)
400
+
401
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
402
+
403
+ offsets += BLOCK_X
404
+
405
+
406
+ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
407
+ # Validate the input shapes.
408
+ assert_is_matrix(x)
409
+ assert_is_vector(indices)
410
+ assert_is_vector(bins)
411
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
412
+
413
+ if weights is not None:
414
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
415
+
416
+ num_experts = bins.shape[0]
417
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
418
+
419
+ _binned_copy[(num_experts, expert_capacity)](
420
+ x,
421
+ out,
422
+ num_experts,
423
+ expert_capacity,
424
+ indices,
425
+ weights,
426
+ bins,
427
+ NUM_COLUMNS=x.shape[1],
428
+ A_TO_B=True,
429
+ TOP_K=top_k,
430
+ SCALE=weights is not None,
431
+ )
432
+ return out
433
+
434
+
435
+ def binned_scatter(x, indices, weights, bins, top_k):
436
+ # Validate the input shapes.
437
+ assert_is_tensor(x, 3)
438
+ assert_is_vector(indices)
439
+ assert_is_vector(bins)
440
+ assert_equal(bins.shape[0], x.shape[0])
441
+
442
+ if weights is not None:
443
+ assert_equal(indices.shape[0], weights.shape[0])
444
+
445
+ num_experts, expert_capacity, hidden_size = x.shape
446
+ tokens = indices.shape[0] // top_k
447
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
448
+ _binned_copy[(num_experts, expert_capacity)](
449
+ out,
450
+ x,
451
+ num_experts,
452
+ expert_capacity,
453
+ indices,
454
+ weights,
455
+ bins,
456
+ NUM_COLUMNS=hidden_size,
457
+ A_TO_B=False,
458
+ TOP_K=top_k,
459
+ SCALE=weights is not None,
460
+ )
461
+
462
+ # Reduce along the top-k dimension, if needed.
463
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
464
+
465
+
466
+ # a: (tokens, hidden_size), real.
467
+ # b: (num_experts, expert_capacity, num_columns), real.
468
+ # indices: (tokens * top_k), integer.
469
+ # weights: (tokens * top_k), real.
470
+ # bins: (num_experts), integer.
471
+ @triton.autotune(
472
+ configs=[
473
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
474
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
475
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
476
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
477
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
478
+ ],
479
+ key=['NUM_COLUMNS'],
480
+ )
481
+ @triton.jit
482
+ def _binned_copy_wgrad(
483
+ x,
484
+ grad,
485
+ wgrad,
486
+ num_experts,
487
+ expert_capacity,
488
+ indices,
489
+ bins,
490
+ NUM_COLUMNS: tl.constexpr,
491
+ TOP_K: tl.constexpr,
492
+ BLOCK_X: tl.constexpr,
493
+ ):
494
+ # Load our indices into the output.
495
+ expert_idx = tl.program_id(0)
496
+ entry_idx = tl.program_id(1)
497
+
498
+ # Calculate our offset into the output.
499
+ index_x = expert_idx * expert_capacity + entry_idx
500
+
501
+ # Load the index bounds for our bin and calculate
502
+ # the number of tokens assigned to our expert.
503
+ start = 0
504
+ if expert_idx > 0:
505
+ start = tl.load(bins + expert_idx - 1)
506
+ end = tl.load(bins + expert_idx)
507
+ num_tokens = end - start
508
+
509
+ # Calculate our offset into the input. If we don't
510
+ # have an input, exit early.
511
+ if entry_idx >= num_tokens:
512
+ return
513
+ index_out = tl.load(indices + start + entry_idx)
514
+
515
+ # Offset the input and output pointers.
516
+ wgrad += index_out
517
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
518
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
519
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
520
+
521
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
522
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
523
+ for _ in range(iterations):
524
+ mask = offsets < NUM_COLUMNS
525
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
526
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
527
+ acc += data * scale
528
+ offsets += BLOCK_X
529
+
530
+ # Reduce to get the final result and store.
531
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
532
+ tl.store(wgrad, out)
533
+
534
+
535
+ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
536
+ # Validate the input shapes.
537
+ assert_is_tensor(x, 3)
538
+ assert_is_matrix(grad)
539
+ assert_is_vector(indices)
540
+ assert_is_vector(bins)
541
+ assert_equal(bins.shape[0], x.shape[0])
542
+
543
+ num_experts, expert_capacity, hidden_size = x.shape
544
+ tokens = indices.shape[0] // top_k
545
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
546
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
547
+ x,
548
+ grad,
549
+ out,
550
+ num_experts,
551
+ expert_capacity,
552
+ indices,
553
+ bins,
554
+ NUM_COLUMNS=hidden_size,
555
+ TOP_K=top_k,
556
+ )
557
+ return out
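The binned kernels above are easiest to follow with a tiny round trip. This is an illustrative sketch only, assuming a ROCm/CUDA device with Triton available and using the `binned_gather`/`binned_scatter` wrappers defined in this file; all sizes are arbitrary.

import torch

tokens, hidden, num_experts, top_k = 8, 16, 4, 1
x = torch.randn(tokens, hidden, device="cuda", dtype=torch.float16)
top_expert = torch.randint(0, num_experts, (tokens,), device="cuda")

# Token ids sorted by their assigned expert, plus inclusive-cumsum bin boundaries.
indices = torch.argsort(top_expert).to(torch.int32)
counts = torch.bincount(top_expert, minlength=num_experts)
bins = torch.cumsum(counts, 0).to(torch.int32)
capacity = int(counts.max())

gathered = binned_gather(x, indices, None, bins, capacity, top_k)  # (num_experts, capacity, hidden)
restored = binned_scatter(gathered, indices, None, bins, top_k)    # (tokens, hidden)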
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/bak.__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from megablocks_moe.megablocks import (
2
+ MoE,
3
+ dMoE,
4
+ get_load_balancing_loss,
5
+ ParallelMLP,
6
+ ParallelDroplessMLP,
7
+ SparseMLP,
8
+ MLP,
9
+ SparseGLU,
10
+ Arguments,
11
+ )
12
+
13
+ __all__ = [
14
+ "MoE",
15
+ "dMoE",
16
+ "get_load_balancing_loss",
17
+ "ParallelMLP",
18
+ "ParallelDroplessMLP",
19
+ "SparseMLP",
20
+ "MLP",
21
+ "SparseGLU",
22
+ "Arguments",
23
+ ]
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/benchmark_util.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def log_benchmark(name, arguments, time, std):
9
+ print('=' * 60)
10
+ print(f'{name} Benchmark')
11
+ print('Benchmark Parameters:')
12
+ for (key, value) in arguments.items():
13
+ print(f'{key} = {value}')
14
+ print('Results:')
15
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
16
+ print('=' * 60)
17
+
18
+
19
+ def benchmark_function(fn, iterations=100, warmup=10):
20
+ # Warmup iterations.
21
+ for _ in range(warmup):
22
+ fn()
23
+
24
+ times = []
25
+ for i in range(iterations):
26
+ start = torch.cuda.Event(enable_timing=True)
27
+ end = torch.cuda.Event(enable_timing=True)
28
+
29
+ start.record()
30
+ fn()
31
+ end.record()
32
+
33
+ torch.cuda.synchronize()
34
+ times.append(start.elapsed_time(end))
35
+ return np.mean(times), np.std(times)
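For reference, a minimal usage sketch of the helpers above (illustrative only; it assumes a CUDA/ROCm device, since the timing relies on CUDA events):

import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

mean_ms, std_ms = benchmark_function(lambda: a @ b)
log_benchmark("MatMul", {"m": 1024, "n": 1024, "k": 1024}, mean_ms, std_ms)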
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import ops
2
+ from . import backend
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,33 @@
1
+ # NOTE: Torch needs to be imported before the custom
2
+ # extensions. Otherwise libc10.so cannot be found.
3
+ import torch
4
+
5
+ # # TODO(tgale): Wrap this in a try-block with better
6
+ # # error message and instructions for building the
7
+ # # c++ operations.
8
+ # import grouped_gemm_backend as backend
9
+
10
+ # We import the backend operations from the megablocks package as
11
+ # grouped_gemm is vendored in megablocks in this repository.
12
+ # from ... import _ops as backend
13
+ # from megablocks._ops import ops as backend # type: ignore
14
+ from .._ops import ops as backend # type: ignore
15
+
16
+ def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
+ assert not (trans_a and trans_b)
18
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
19
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
20
+ assert b.ndim == (2 if trans_a else 3)
21
+
22
+ shape = (
23
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
24
+ if trans_a else
25
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
26
+ )
27
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
28
+
29
+ def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
30
+ if c is None:
31
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
32
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
33
+ return c
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
1
+ from . import backend
2
+ import torch
3
+
4
+
5
+ class GroupedGemm(torch.autograd.Function):
6
+
7
+ @staticmethod
8
+ def forward(ctx, a, b, batch_sizes, trans_b):
9
+ ctx.save_for_backward(a, b, batch_sizes)
10
+ ctx.trans_b = trans_b
11
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
12
+
13
+ @staticmethod
14
+ def backward(ctx, grad):
15
+ grad = grad.contiguous()
16
+ a, b, batch_sizes = ctx.saved_tensors
17
+ trans_b = ctx.trans_b
18
+
19
+ agrad = None
20
+ if ctx.needs_input_grad[0]:
21
+ agrad = backend.gmm(
22
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
23
+
24
+ bgrad = None
25
+ if ctx.needs_input_grad[1]:
26
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
27
+ bgrad = backend.gmm(
28
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
29
+ return agrad, bgrad, None, None
30
+
31
+
32
+ def gmm(a, b, batch_sizes, trans_b=False):
33
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
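A hedged usage sketch of the autograd wrapper above, assuming the vendored ROCm grouped-GEMM kernel is built for this target and that `batch_sizes` is a CPU int64 tensor, as in the upstream grouped_gemm API:

import torch

num_experts, k, n = 4, 64, 128
batch_sizes = torch.tensor([3, 5, 2, 6], dtype=torch.long)  # rows of `a` per expert, kept on CPU
a = torch.randn(int(batch_sizes.sum()), k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(num_experts, k, n, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = gmm(a, b, batch_sizes)  # (sum(batch_sizes), n)
out.float().sum().backward()  # gradients flow to both `a` and `b`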
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm_util.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import warnings
4
+
5
+ _grouped_gemm_is_available: bool = False
6
+ try:
7
+ # import grouped_gemm
8
+ pass
9
+ _grouped_gemm_is_available = True
10
+ except ImportError as error:
11
+ warnings.warn('Grouped GEMM not available.')
12
+
13
+
14
+ def grouped_gemm_is_available():
15
+ return _grouped_gemm_is_available
16
+
17
+
18
+ def assert_grouped_gemm_is_available():
19
+ msg = (
20
+ 'Grouped GEMM not available. Please run '
21
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
22
+ )
23
+ assert _grouped_gemm_is_available, msg
24
+
25
+
26
+ # backend = grouped_gemm.backend if grouped_gemm_is_available() else None
27
+ # ops = grouped_gemm.ops if grouped_gemm_is_available() else None
28
+
29
+
30
+ from .grouped_gemm import backend as backend
31
+ from .grouped_gemm import ops as ops
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/layers.py ADDED
@@ -0,0 +1,1225 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ from typing import Optional, Any, TYPE_CHECKING
5
+
6
+ from . import _layers
7
+ from . import ops
8
+
9
+ # Conditional import for meta kernel registration
10
+ if TYPE_CHECKING:
11
+
12
+ def register_fake(fn):
13
+ return lambda name: fn
14
+
15
+ else:
16
+ try:
17
+ from torch.library import register_fake
18
+ except ImportError:
19
+ try:
20
+ from torch.library import impl_abstract as register_fake
21
+ except ImportError:
22
+ # Fallback for older PyTorch versions
23
+ def register_fake(op_name):
24
+ def decorator(fn):
25
+ return fn
26
+
27
+ return decorator
28
+
29
+
30
+ # Meta kernel implementations for torch.compile compatibility
31
+ def _install_meta_kernels():
32
+ """Install meta kernels for existing MegaBlocks operations"""
33
+
34
+ # Create wrapper functions that check for compilation and return meta tensors
35
+
36
+ # Patch ops.sort
37
+ if hasattr(ops, "sort"):
38
+ original_sort = ops.sort
39
+
40
+ def sort_with_meta(x, end_bit=None):
41
+ if torch.compiler.is_compiling():
42
+ # print("Using meta kernel for sort")
43
+ # Meta implementation - return tensors with correct shape/dtype/device
44
+ return torch.empty_like(x), torch.empty_like(x)
45
+ # print("Using original sort kernel")
46
+ return original_sort(x, end_bit)
47
+
48
+ ops.sort = sort_with_meta
49
+
50
+ # Patch ops.histogram
51
+ if hasattr(ops, "histogram"):
52
+ original_histogram = ops.histogram
53
+
54
+ def histogram_with_meta(x, max_val):
55
+ if torch.compiler.is_compiling():
56
+ # Meta implementation
57
+ return torch.empty((max_val,), dtype=torch.int32, device=x.device)
58
+ return original_histogram(x, max_val)
59
+
60
+ ops.histogram = histogram_with_meta
61
+
62
+ # Patch ops.inclusive_cumsum
63
+ if hasattr(ops, "inclusive_cumsum"):
64
+ original_inclusive_cumsum = ops.inclusive_cumsum
65
+
66
+ def inclusive_cumsum_with_meta(x, dim):
67
+ if torch.compiler.is_compiling():
68
+ # Meta implementation
69
+ return torch.empty_like(x)
70
+ return original_inclusive_cumsum(x, dim)
71
+
72
+ ops.inclusive_cumsum = inclusive_cumsum_with_meta
73
+
74
+ # Patch ops.binned_gather
75
+ if hasattr(ops, "binned_gather"):
76
+ original_binned_gather = ops.binned_gather
77
+
78
+ def binned_gather_with_meta(x, indices, bins, bin_size, top_k):
79
+ if torch.compiler.is_compiling():
80
+ # Meta implementation - output shape based on bin_size
81
+ if x.dim() >= 2:
82
+ hidden_size = x.size(-1)
83
+ return torch.empty(
84
+ (bin_size, x.size(1), hidden_size),
85
+ dtype=x.dtype,
86
+ device=x.device,
87
+ )
88
+ else:
89
+ return torch.empty((bin_size,), dtype=x.dtype, device=x.device)
90
+ return original_binned_gather(x, indices, bins, bin_size, top_k)
91
+
92
+ ops.binned_gather = binned_gather_with_meta
93
+
94
+ # Patch ops.binned_scatter
95
+ if hasattr(ops, "binned_scatter"):
96
+ original_binned_scatter = ops.binned_scatter
97
+
98
+ def binned_scatter_with_meta(x, indices, weights, bins, top_k):
99
+ if torch.compiler.is_compiling():
100
+ # Meta implementation - typically reduces to 2D
101
+ if x.dim() >= 3:
102
+ return torch.empty(
103
+ (x.size(1), x.size(2)), dtype=x.dtype, device=x.device
104
+ )
105
+ else:
106
+ return torch.empty_like(x)
107
+ return original_binned_scatter(x, indices, weights, bins, top_k)
108
+
109
+ ops.binned_scatter = binned_scatter_with_meta
110
+
111
+ # Patch ops.gather
112
+ if hasattr(ops, "gather"):
113
+ original_gather = ops.gather
114
+
115
+ def gather_with_meta(x, indices, bin_ids, bins, top_k):
116
+ if torch.compiler.is_compiling():
117
+ # Meta implementation
118
+ if x.dim() >= 2:
119
+ hidden_size = x.size(-1)
120
+ return torch.empty(
121
+ (indices.numel(), hidden_size), dtype=x.dtype, device=x.device
122
+ )
123
+ else:
124
+ return torch.empty(indices.shape, dtype=x.dtype, device=x.device)
125
+ return original_gather(x, indices, bin_ids, bins, top_k)
126
+
127
+ ops.gather = gather_with_meta
128
+
129
+ # Patch ops.scatter
130
+ if hasattr(ops, "scatter"):
131
+ original_scatter = ops.scatter
132
+
133
+ def scatter_with_meta(x, indices, bin_ids, weights, bins, top_k):
134
+ if torch.compiler.is_compiling():
135
+ # Meta implementation - restore sequence shape
136
+ seq_len = (
137
+ indices.size(0) // top_k
138
+ if indices.numel() > 0 and top_k > 0
139
+ else x.size(0)
140
+ )
141
+ if x.dim() >= 2:
142
+ return torch.empty(
143
+ (seq_len, x.size(-1)), dtype=x.dtype, device=x.device
144
+ )
145
+ else:
146
+ return torch.empty((seq_len,), dtype=x.dtype, device=x.device)
147
+ return original_scatter(x, indices, bin_ids, weights, bins, top_k)
148
+
149
+ ops.scatter = scatter_with_meta
150
+
151
+ # Patch ops.replicate
152
+ if hasattr(ops, "replicate"):
153
+ original_replicate = ops.replicate
154
+
155
+ def replicate_with_meta(x, bins, num_outputs):
156
+ if torch.compiler.is_compiling():
157
+ # Meta implementation
158
+ return torch.empty(
159
+ (x.shape[0], num_outputs), dtype=x.dtype, device=x.device
160
+ )
161
+ return original_replicate(x, bins, num_outputs)
162
+
163
+ ops.replicate = replicate_with_meta
164
+
165
+ # Patch ops.repeat (if it's a regular function)
166
+ if hasattr(ops, "repeat"):
167
+ original_repeat = ops.repeat
168
+
169
+ def repeat_with_meta(x, repeats):
170
+ if torch.compiler.is_compiling():
171
+ # Meta implementation
172
+ if isinstance(repeats, (tuple, list)):
173
+ new_shape = list(x.shape)
174
+ for i, rep in enumerate(repeats):
175
+ if i < len(new_shape):
176
+ new_shape[i] *= rep
177
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
178
+ else:
179
+ new_shape = [x.size(0) * repeats] + list(x.shape[1:])
180
+ return torch.empty(new_shape, dtype=x.dtype, device=x.device)
181
+ return original_repeat(x, repeats)
182
+
183
+ ops.repeat = repeat_with_meta
184
+
185
+
186
+ # Install meta kernels on import
187
+ try:
188
+ _install_meta_kernels()
189
+ except Exception as e:
190
+ # If meta kernel installation fails, continue without them
191
+ # torch.compile may not work but the library will still function
192
+ import warnings
193
+
194
+ warnings.warn(
195
+ f"Failed to install meta kernels for torch.compile support: {e}", UserWarning
196
+ )
197
+
198
+
199
+ # Set the expert model parallel attributes on a tensor
200
+ def set_expert_model_parallel_attributes(
201
+ tensor: torch.Tensor,
202
+ is_parallel: bool,
203
+ ):
204
+ assert not hasattr(tensor, "expert_model_parallel")
205
+ setattr(tensor, "expert_model_parallel", is_parallel)
206
+
207
+
208
+ # Get the expert model parallel attributes from a tensor
209
+ def expert_sharding_degree(
210
+ world_size: int,
211
+ moe_num_experts: int,
212
+ ) -> int:
213
+ esd = min(world_size, moe_num_experts)
214
+ if (moe_num_experts % esd) != 0:
215
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
216
+ return esd
217
+
218
+
219
+ # Calculate the hidden sharding degree based on world size and expert sharding degree
220
+ def hidden_sharding_degree(
221
+ world_size: int,
222
+ moe_num_experts: int,
223
+ ffn_hidden_size: int,
224
+ ) -> int:
225
+ esd = expert_sharding_degree(world_size, moe_num_experts)
226
+ hsd = world_size // esd
227
+ if (ffn_hidden_size % hsd) != 0:
228
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
229
+ if (esd * hsd) != world_size:
230
+ raise ValueError(
231
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
232
+ )
233
+ return hsd
234
+
235
+
236
+ # Calculate the number of experts per rank based on world size and expert sharding degree
237
+ def experts_per_rank(
238
+ moe_num_experts: int,
239
+ world_size: int,
240
+ ) -> int:
241
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
242
+
243
+
244
+ # Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
245
+ def features_per_rank(
246
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
247
+ ) -> int:
248
+ return ffn_hidden_size // hidden_sharding_degree(
249
+ world_size, moe_num_experts, ffn_hidden_size
250
+ )
251
+
252
+
253
+ # Apply jitter to the input tensor
254
+ def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
255
+ low = 1.0 - moe_jitter_eps
256
+ high = 1.0 + moe_jitter_eps
257
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
258
+ return x * (low + noise * (high - low))
259
+
260
+
261
+ # Compute the top-k scores from the logits
262
+ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
263
+ if moe_top_k == 1:
264
+ return scores.max(dim=-1, keepdim=True)
265
+ return torch.topk(scores, moe_top_k, dim=-1)
266
+
267
+
268
+ # Route tokens to experts and compute expert weights and indices
269
+ def route_tokens(
270
+ x: torch.Tensor,
271
+ router_weight: torch.Tensor,
272
+ router_bias: torch.Tensor,
273
+ moe_top_k: int,
274
+ moe_num_experts: int,
275
+ moe_jitter_eps: float = None,
276
+ moe_normalize_expert_weights: int = None,
277
+ uniform_expert_assignment: bool = False,
278
+ training: bool = False,
279
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
280
+ if training and moe_jitter_eps is not None:
281
+ x = apply_jitter(x, moe_jitter_eps)
282
+
283
+ x_flat = x.view(-1, x.shape[-1])
284
+ logits = torch.nn.functional.linear(x_flat, router_weight, router_bias)
285
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
286
+ expert_weights = expert_weights.softmax(dim=-1)
287
+ if moe_normalize_expert_weights is not None:
288
+ expert_weights = expert_weights / torch.norm(
289
+ expert_weights,
290
+ p=moe_normalize_expert_weights,
291
+ dim=-1,
292
+ keepdim=True,
293
+ )
294
+ if uniform_expert_assignment:
295
+ expert_indices = _layers.router._uniform_expert_assignment(
296
+ expert_indices,
297
+ moe_num_experts,
298
+ )
299
+
300
+ return logits, expert_weights, expert_indices
301
+
302
+
303
+ # Scale the gradient of the weights
304
+ def scale_grad(
305
+ w: torch.Tensor,
306
+ gradient_scale: Optional[float] = None,
307
+ ) -> torch.Tensor:
308
+ if gradient_scale is None:
309
+ return w
310
+ return _layers.mlp.scale_gradient(w, gradient_scale)
311
+
312
+
313
+ # Forward pass for the MLP layer
314
+ def mlp_forward(
315
+ x: torch.Tensor,
316
+ w1: torch.Tensor,
317
+ w2: torch.Tensor,
318
+ w1_bias: torch.Tensor,
319
+ w2_bias: torch.Tensor,
320
+ gradient_scale: Optional[float] = None,
321
+ alpha: float = 1.702,
322
+ limit: float = 7.0,
323
+ ):
324
+ # Scale weights
325
+ w1 = scale_grad(w1, gradient_scale)
326
+ w2 = scale_grad(w2, gradient_scale)
327
+ w1_bias = scale_grad(w1_bias, gradient_scale)
328
+ w2_bias = scale_grad(w2_bias, gradient_scale)
329
+
330
+ # Resolve dtensors
331
+ w1 = _layers.mlp.resolve_dtensor(w1)
332
+ w2 = _layers.mlp.resolve_dtensor(w2)
333
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
334
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
335
+
336
+ # Forward pass
337
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
338
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
339
+ gate = gate.clamp(min=None, max=limit)
340
+ up = up.clamp(min=-limit, max=limit)
341
+ glu = gate * torch.sigmoid(gate * alpha)
342
+ next_states = torch.bmm(((up + 1) * glu), w2)
343
+ next_states += w2_bias[..., None, :]
344
+ return next_states
345
+
346
+ # Shared expert MLP forward pass
347
+ def shared_mlp_forward(
348
+ x: torch.Tensor,
349
+ up_proj_weight: torch.Tensor,
350
+ down_proj_weight: torch.Tensor,
351
+ up_proj_bias: Optional[torch.Tensor] = None,
352
+ down_proj_bias: Optional[torch.Tensor] = None,
353
+ activation_fn: Optional[Any] = None,
354
+ gradient_scale: Optional[float] = None,
355
+ ) -> torch.Tensor:
356
+ # Default activation function
357
+ if activation_fn is None:
358
+ activation_fn = torch.nn.functional.gelu
359
+
360
+ # Scale weights
361
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
362
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
363
+ if up_proj_bias is not None:
364
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
365
+ if down_proj_bias is not None:
366
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
367
+
368
+ # Resolve dtensors
369
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
370
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
371
+ if up_proj_bias is not None:
372
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
373
+ if down_proj_bias is not None:
374
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
375
+
376
+ # Up projection
377
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
378
+
379
+ # Activation
380
+ x = activation_fn(x)
381
+
382
+ # Down projection
383
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
384
+
385
+ return x
386
+
387
+
388
+ # Combine outputs from shared expert and regular experts
389
+ def combine_expert_shared_outputs(
390
+ shared_expert_out: torch.Tensor,
391
+ expert_out: torch.Tensor,
392
+ shared_expert_weighted_sum: bool = False,
393
+ moe_top_k: int = 1,
394
+ ) -> torch.Tensor:
395
+ if shared_expert_weighted_sum:
396
+ # Weighted sum based on number of experts used
397
+ total_experts = moe_top_k + 1
398
+ shared_weight = 1.0 / total_experts
399
+ expert_weight = moe_top_k / total_experts
400
+ return shared_expert_out * shared_weight + expert_out * expert_weight
401
+ else:
402
+ # Simple addition
403
+ return shared_expert_out + expert_out
404
+
405
+
406
+ # Global variable to store load balancing loss
407
+ _LOAD_BALANCING_LOSS = []
408
+
409
+
410
+ def save_load_balancing_loss(loss):
411
+ global _LOAD_BALANCING_LOSS
412
+ _LOAD_BALANCING_LOSS.append(loss)
413
+
414
+
415
+ def get_load_balancing_loss():
416
+ global _LOAD_BALANCING_LOSS
417
+ return _LOAD_BALANCING_LOSS
418
+
419
+
420
+ def clear_load_balancing_loss():
421
+ global _LOAD_BALANCING_LOSS
422
+ _LOAD_BALANCING_LOSS.clear()
423
+
424
+
425
+ def batched_load_balancing_loss(args):
426
+ if args.moe_loss_weight == 0:
427
+ return 0.0
428
+
429
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
430
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
431
+ if args.num_layers_per_virtual_pipeline_stage is not None:
432
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
433
+
434
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
435
+ raise ValueError(
436
+ f"Expected {num_layers_per_pipeline_stage} tokens_per_expert entries "
437
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
438
+ f"{args.num_layers}\npipeline_model_parallel_size = "
439
+ f"{args.pipeline_model_parallel_size}\n"
440
+ "num_layers_per_virtual_pipeline_stage"
441
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
442
+ )
443
+ if len(expert_scores) != num_layers_per_pipeline_stage:
444
+ raise ValueError(
445
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
446
+ f"but found {len(expert_scores)}.\nnum_layers = "
447
+ f"{args.num_layers}\npipeline_model_parallel_size = "
448
+ f"{args.pipeline_model_parallel_size}\n"
449
+ "num_layers_per_virtual_pipeline_stage"
450
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
451
+ )
452
+
453
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
454
+ assert all(
455
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
456
+ )
457
+
458
+ tokens = expert_scores[0].shape[0]
459
+ assert all(
460
+ (
461
+ (
462
+ x.ndim == 2
463
+ and x.shape[1] == args.moe_num_experts
464
+ and x.shape[0] == tokens
465
+ )
466
+ for x in expert_scores
467
+ )
468
+ )
469
+
470
+ # Concatenate the contributions of each layer and convert to
471
+ # the correct types and formats for the dot product.
472
+ expert_scores = torch.cat(expert_scores, dim=1)
473
+ if args.moe_lbl_in_fp32:
474
+ expert_scores = expert_scores.float()
475
+ if tokens != 0:
476
+ expert_scores = expert_scores.mean(dim=0)
477
+ else:
478
+ expert_scores = expert_scores.sum(dim=0)
479
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
480
+
481
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
482
+ assert tokens_per_expert.numel() == expected_values
483
+ assert expert_scores.numel() == expected_values
484
+
485
+ # Calculate the total scale across all factors.
486
+ #
487
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
488
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
489
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
490
+ scale = scale_numerator / scale_denominator
491
+ return scale * torch.dot(tokens_per_expert, expert_scores)
492
+
493
+
494
+ # Calculate the expert capacity based on tokens, top_k, number of experts,
495
+ # expert parallel group, capacity factor, and whether expert model parallelism is used.
496
+ def expert_capacity(
497
+ tokens: int,
498
+ top_k: int,
499
+ num_experts: int,
500
+ expert_parallel_group: int,
501
+ moe_capacity_factor: float,
502
+ moe_expert_model_parallelism: bool,
503
+ ) -> int:
504
+ world_size = (
505
+ dist.get_world_size(expert_parallel_group)
506
+ if moe_expert_model_parallelism
507
+ else 1
508
+ )
509
+
510
+ tokens_per_expert = top_k * tokens * world_size / num_experts
511
+ return int(moe_capacity_factor * tokens_per_expert)
512
+
513
+
514
+ def load_balancing_loss(
515
+ tokens_per_expert: torch.Tensor,
516
+ expert_scores: torch.Tensor,
517
+ top_k: int,
518
+ num_experts: int,
519
+ ):
520
+ assert len(expert_scores.size()) == 2
521
+ tokens, scores_num_experts = expert_scores.size()
522
+ assert scores_num_experts == num_experts
523
+ assert len(tokens_per_expert.size()) == 1
524
+ (counts_num_experts,) = tokens_per_expert.size()
525
+ assert counts_num_experts == num_experts
526
+ scale = num_experts / (tokens * top_k)
527
+ return scale * torch.dot(
528
+ tokens_per_expert.to(expert_scores.dtype),
529
+ expert_scores.mean(dim=0),
530
+ )
531
+
532
+
533
+ def indices_and_bins(
534
+ top_expert: torch.Tensor,
535
+ sort_end_bit: int,
536
+ num_experts: int,
537
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
538
+ top_expert = top_expert.int()
539
+
540
+ # Ensure contiguous memory layout
541
+ top_expert = top_expert.contiguous()
542
+
543
+ # Ensure CUB knows which device to use
544
+ with torch.cuda.device(top_expert.device):
545
+ output = ops.sort(top_expert, sort_end_bit)
546
+ bin_ids, indices = output
547
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
548
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
549
+
550
+ bins = bins.view(1) if not len(bins.size()) else bins
551
+ return indices, bin_ids, bins, tokens_per_expert
552
+
553
+
554
+ def expert_capacity_fn(
555
+ tokens: int,
556
+ top_k: int,
557
+ num_experts: int,
558
+ expert_parallel_group: torch.distributed.ProcessGroup,
559
+ moe_capacity_factor: float = 1.0,
560
+ moe_expert_model_parallelism: bool = False,
561
+ ) -> int:
562
+ world_size = (
563
+ dist.get_world_size(expert_parallel_group)
564
+ if moe_expert_model_parallelism
565
+ else 1
566
+ )
567
+ tokens_per_expert = top_k * tokens * world_size / num_experts
568
+ return int(moe_capacity_factor * tokens_per_expert)
569
+
570
+
571
+ def permute_and_compute(
572
+ x,
573
+ tokens_per_expert,
574
+ indices,
575
+ bin_ids,
576
+ expert_weights,
577
+ bins,
578
+ expert_capacity,
579
+ top_k,
580
+ w1,
581
+ w2,
582
+ w1_bias,
583
+ w2_bias,
584
+ gradient_scale,
585
+ alpha,
586
+ ):
587
+ # Route tokens to experts
588
+ x = x.view(-1, x.shape[-1])
589
+
590
+ # Ensure CUB knows which device to use
591
+ with torch.cuda.device(x.device):
592
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
593
+
594
+ # Expert computation
595
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
596
+
597
+ # Ensure CUB knows which device to use
598
+ with torch.cuda.device(x.device):
599
+ # Route tokens back
600
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
601
+ return out
602
+
603
+
604
+ def forward_once(
605
+ x: torch.Tensor,
606
+ expert_weights: torch.Tensor,
607
+ top_experts: torch.Tensor,
608
+ w1: torch.Tensor,
609
+ w2: torch.Tensor,
610
+ w1_bias: torch.Tensor,
611
+ w2_bias: torch.Tensor,
612
+ gradient_scale: Optional[float] = None,
613
+ alpha: float = 1.702,
614
+ sort_end_bit: int = 0,
615
+ top_k: int = 4,
616
+ num_experts: int = 128,
617
+ expert_parallel_group: int = None,
618
+ moe_capacity_factor: float = 1.0,
619
+ moe_expert_model_parallelism: bool = False,
620
+ mlp_impl: Optional[str] = None,
621
+ ):
622
+ # x: [sl, bs, hs]
623
+ # expert_weights: [sl * bs, top-k]
624
+ # top_experts: [sl * bs, top-k]
625
+ expert_weights = expert_weights.flatten()
626
+ top_experts = top_experts.flatten()
627
+
628
+ with torch.no_grad():
629
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
630
+ top_experts, sort_end_bit, num_experts
631
+ )
632
+
633
+ # Calculate expert capacity
634
+ sl, bs, _ = x.size()
635
+
636
+ expert_capacity = expert_capacity_fn(
637
+ sl * bs,
638
+ top_k,
639
+ num_experts,
640
+ expert_parallel_group,
641
+ moe_capacity_factor,
642
+ moe_expert_model_parallelism,
643
+ )
644
+
645
+ if expert_capacity == 0:
646
+ expert_capacity = torch.max(tokens_per_expert).item()
647
+
648
+ x = permute_and_compute(
649
+ x,
650
+ tokens_per_expert,
651
+ indices,
652
+ bin_ids,
653
+ expert_weights,
654
+ bins,
655
+ expert_capacity,
656
+ top_k,
657
+ w1,
658
+ w2,
659
+ w1_bias,
660
+ w2_bias,
661
+ gradient_scale,
662
+ alpha,
663
+ )
664
+ return x, tokens_per_expert
665
+
666
+
667
+ def parallel_forward_once(
668
+ x: torch.Tensor,
669
+ expert_weights: torch.Tensor,
670
+ top_experts: torch.Tensor,
671
+ w1: torch.Tensor,
672
+ w2: torch.Tensor,
673
+ w1_bias: torch.Tensor,
674
+ w2_bias: torch.Tensor,
675
+ gradient_scale: Optional[float] = None,
676
+ alpha: float = 1.702,
677
+ sort_end_bit: int = 0,
678
+ top_k: int = 4,
679
+ num_experts: int = 128,
680
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
681
+ moe_capacity_factor: float = 1.0,
682
+ moe_expert_model_parallelism: bool = True,
683
+ hidden_size: int = 1152,
684
+ mlp_impl: Optional[str] = "grouped",
685
+ ):
686
+ # Flatten inputs
687
+ expert_weights = expert_weights.flatten()
688
+ top_experts = top_experts.flatten()
689
+
690
+ # TODO: remove debugging var
691
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
692
+
693
+ with torch.no_grad():
694
+ # Step 1: Local permutation setup
695
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
696
+ top_experts, sort_end_bit, num_experts
697
+ )
698
+
699
+ # Calculate sharding parameters
700
+ world_size = dist.get_world_size(expert_parallel_group)
701
+ hidden_sharding_deg = hidden_sharding_degree(
702
+ world_size, num_experts, hidden_size
703
+ )
704
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
705
+
706
+ # Replicate token counts for hidden sharding
707
+ repeated_tokens_per_expert = ops.repeat(
708
+ tokens_per_expert, (hidden_sharding_deg,)
709
+ )
710
+
711
+ # Exchange token counts across devices
712
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
713
+
714
+ # Ensure CUB knows which device to use
715
+ tpe_handle = dist.all_to_all_single(
716
+ parallel_tokens_per_expert,
717
+ repeated_tokens_per_expert,
718
+ group=expert_parallel_group,
719
+ async_op=True,
720
+ )
721
+
722
+ # Step 2: Local permutation - group tokens by target device
723
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
724
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
725
+
726
+ # Step 3: Compute communication counts and exchange tokens
727
+ with torch.no_grad():
728
+ tpe_handle.wait()
729
+
730
+ # Reshape for per-device calculations
731
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
732
+ world_size, experts_per_rank_val
733
+ )
734
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
735
+ world_size, experts_per_rank_val
736
+ )
737
+
738
+ # Calculate send/recv counts
739
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
740
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
741
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
742
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
743
+ tokens_received = sum(recv_counts)
744
+
745
+ # Replicate for hidden sharding
746
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
747
+
748
+ # Cross-device token exchange
749
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
750
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
751
+ )
752
+
753
+ with torch.no_grad():
754
+ # Step 4: Setup for local expert computation
755
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
756
+ replicate_bins = (
757
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
758
+ )
759
+
760
+ # Create expert indices for received tokens
761
+ parallel_top_expert = torch.remainder(
762
+ torch.arange(
763
+ num_experts * hidden_sharding_deg,
764
+ dtype=torch.int32,
765
+ device=indices.device,
766
+ ),
767
+ experts_per_rank_val,
768
+ )
769
+ parallel_top_expert = ops.replicate(
770
+ parallel_top_expert.unsqueeze(dim=0),
771
+ replicate_bins,
772
+ tokens_received,
773
+ ).flatten()
774
+
775
+ # Sort tokens by expert assignment
776
+ parallel_bin_ids, parallel_indices = ops.sort(
777
+ parallel_top_expert,
778
+ sort_end_bit,
779
+ )
780
+
781
+ # Calculate bins for local experts
782
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
783
+ dim=0, dtype=torch.int
784
+ )
785
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
786
+ parallel_bins = (
787
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
788
+ )
789
+
790
+ # Calculate expert capacity
791
+ expert_capacity = expert_capacity_fn(
792
+ tokens_received,
793
+ top_k,
794
+ experts_per_rank_val,
795
+ expert_parallel_group,
796
+ moe_capacity_factor,
797
+ moe_expert_model_parallelism,
798
+ )
799
+ if expert_capacity == 0:
800
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
801
+
802
+ # Locally permute the tokens and perform the expert computation.
803
+ # Block to make sure that the cross-device permutation is complete.
804
+ if mlp_impl == "grouped":
805
+ # GroupedMLP requires counts on CPU. We can use the tensor already
806
+ # moved to CPU for the prior all_to_all, which avoids an extra
807
+ # device synchronization.
808
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
809
+ dim=0,
810
+ dtype=torch.int,
811
+ )
812
+
813
+ # Step 5: Expert computation
814
+ parallel_x_handle.wait()
815
+
816
+ parallel_x = permute_and_compute(
817
+ parallel_x,
818
+ parallel_tokens_per_expert,
819
+ parallel_indices,
820
+ parallel_bin_ids,
821
+ None, # expert_weights
822
+ parallel_bins,
823
+ expert_capacity,
824
+ top_k=1,
825
+ w1=w1,
826
+ w2=w2,
827
+ w1_bias=w1_bias,
828
+ w2_bias=w2_bias,
829
+ gradient_scale=gradient_scale,
830
+ alpha=alpha,
831
+ )
832
+
833
+ # Step 6: Reverse communication - send results back
834
+ x, _ = _layers.all_to_all.all_to_all(
835
+ parallel_x, send_counts, recv_counts, expert_parallel_group
836
+ )
837
+
838
+ # Step 7: Reduce across hidden sharding dimension
839
+ shape = (hidden_sharding_deg, -1, hidden_size)
840
+ x = x.view(shape).sum(dim=0)
841
+
842
+ # Step 8: Final local unpermutation
843
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
844
+
845
+ return x, tokens_per_expert.flatten()
846
+
847
+
848
+ def moe_forward(
849
+ x: torch.Tensor,
850
+ router_weight: torch.Tensor,
851
+ router_bias: Optional[torch.Tensor],
852
+ moe_top_k: int,
853
+ moe_num_experts: int,
854
+ moe_jitter_eps: float = None,
855
+ moe_normalize_expert_weights: int = None,
856
+ uniform_expert_assignment: bool = False,
857
+ training: bool = False,
858
+ w1: torch.Tensor = None,
859
+ w2: torch.Tensor = None,
860
+ w1_bias: torch.Tensor = None,
861
+ w2_bias: torch.Tensor = None,
862
+ gradient_scale: Optional[float] = None,
863
+ alpha: float = 1.702,
864
+ sort_end_bit: int = 0,
865
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
866
+ moe_capacity_factor: float = 1.0,
867
+ moe_expert_model_parallelism: bool = False,
868
+ forward_fn: Any = None,
869
+ hidden_size: int = None,
870
+ mlp_impl: str = "grouped",
871
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
872
+
873
+ # Route tokens to experts
874
+ logits, expert_weights, expert_indices = route_tokens(
875
+ x,
876
+ router_weight,
877
+ router_bias,
878
+ moe_top_k,
879
+ moe_num_experts,
880
+ moe_jitter_eps,
881
+ moe_normalize_expert_weights,
882
+ uniform_expert_assignment,
883
+ training,
884
+ )
885
+
886
+ # Create router scores for output
887
+ router_scores = (
888
+ torch.zeros_like(logits)
889
+ .scatter_(1, expert_indices, expert_weights)
890
+ .transpose(0, 1)
891
+ )
892
+
893
+ in_shape = x.size()
894
+
895
+ # Prepare forward function arguments
896
+ forward_args = {
897
+ "x": x,
898
+ "expert_weights": expert_weights,
899
+ "top_experts": expert_indices,
900
+ "w1": w1,
901
+ "w2": w2,
902
+ "w1_bias": w1_bias,
903
+ "w2_bias": w2_bias,
904
+ "gradient_scale": gradient_scale,
905
+ "alpha": alpha,
906
+ "sort_end_bit": sort_end_bit,
907
+ "top_k": moe_top_k,
908
+ "num_experts": moe_num_experts,
909
+ "expert_parallel_group": expert_parallel_group,
910
+ "moe_capacity_factor": moe_capacity_factor,
911
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
912
+ "mlp_impl": mlp_impl,
913
+ }
914
+
915
+ # Add hidden_size for parallel forward
916
+ if moe_expert_model_parallelism and hidden_size is not None:
917
+ forward_args["hidden_size"] = hidden_size
918
+ elif moe_expert_model_parallelism and hidden_size is None:
919
+ # Infer hidden_size from input shape
920
+ forward_args["hidden_size"] = x.shape[-1]
921
+
922
+ # Compute expert outputs
923
+ x, tokens_per_expert = forward_fn(**forward_args)
924
+
925
+ # Save load balancing loss if needed
926
+ moe_loss_weight = 0.0 # Can be made configurable
927
+ if training and moe_loss_weight > 0:
928
+ save_load_balancing_loss((tokens_per_expert, logits))
929
+
930
+ # Restore original shape
931
+ x = x.view(in_shape)
932
+
933
+ return x, expert_weights, router_scores
934
+
935
+
936
+ def moe_forward_with_shared_expert(
937
+ x: torch.Tensor,
938
+ router_weight: torch.Tensor,
939
+ router_bias: Optional[torch.Tensor],
940
+ moe_top_k: int,
941
+ moe_num_experts: int,
942
+ moe_jitter_eps: float = None,
943
+ moe_normalize_expert_weights: int = None,
944
+ uniform_expert_assignment: bool = False,
945
+ training: bool = False,
946
+ w1: torch.Tensor = None,
947
+ w2: torch.Tensor = None,
948
+ w1_bias: torch.Tensor = None,
949
+ w2_bias: torch.Tensor = None,
950
+ gradient_scale: Optional[float] = None,
951
+ alpha: float = 1.702,
952
+ sort_end_bit: int = 0,
953
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
954
+ moe_capacity_factor: float = 1.0,
955
+ moe_expert_model_parallelism: bool = False,
956
+ forward_fn: Any = None,
957
+ hidden_size: int = None,
958
+ mlp_impl: str = "grouped",
959
+ # Shared expert parameters
960
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
961
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
962
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
963
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
964
+ shared_expert_weighted_sum: bool = False,
965
+ shared_activation_fn: Optional[Any] = None,
966
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
967
+
968
+ # First, compute regular MoE forward pass
969
+ expert_out, expert_weights, router_scores = moe_forward(
970
+ x=x,
971
+ router_weight=router_weight,
972
+ router_bias=router_bias,
973
+ moe_top_k=moe_top_k,
974
+ moe_num_experts=moe_num_experts,
975
+ moe_jitter_eps=moe_jitter_eps,
976
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
977
+ uniform_expert_assignment=uniform_expert_assignment,
978
+ training=training,
979
+ w1=w1,
980
+ w2=w2,
981
+ w1_bias=w1_bias,
982
+ w2_bias=w2_bias,
983
+ gradient_scale=gradient_scale,
984
+ alpha=alpha,
985
+ sort_end_bit=sort_end_bit,
986
+ expert_parallel_group=expert_parallel_group,
987
+ moe_capacity_factor=moe_capacity_factor,
988
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
989
+ forward_fn=forward_fn,
990
+ hidden_size=hidden_size,
991
+ mlp_impl=mlp_impl,
992
+ )
993
+
994
+ # If shared expert weights provided, compute shared expert output
995
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
996
+ shared_expert_out = shared_mlp_forward(
997
+ x=x,
998
+ up_proj_weight=shared_up_proj_weight,
999
+ down_proj_weight=shared_down_proj_weight,
1000
+ up_proj_bias=shared_up_proj_bias,
1001
+ down_proj_bias=shared_down_proj_bias,
1002
+ activation_fn=shared_activation_fn,
1003
+ gradient_scale=gradient_scale,
1004
+ )
1005
+
1006
+ # Combine expert outputs
1007
+ combined_out = combine_expert_shared_outputs(
1008
+ shared_expert_out=shared_expert_out,
1009
+ expert_out=expert_out,
1010
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
1011
+ moe_top_k=moe_top_k,
1012
+ )
1013
+
1014
+ return combined_out, expert_weights, router_scores
1015
+
1016
+ # Return regular MoE output if no shared expert
1017
+ return expert_out, expert_weights, router_scores
1018
+
1019
+
1020
+ def create_shared_expert_weights(
1021
+ hidden_size: int,
1022
+ shared_expert_hidden_size: int,
1023
+ device: torch.device,
1024
+ dtype: torch.dtype,
1025
+ init_method: Any,
1026
+ output_layer_init_method: Any = None,
1027
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
1028
+
1029
+ if output_layer_init_method is None:
1030
+ output_layer_init_method = init_method
1031
+
1032
+ # Create weight tensors
1033
+ up_proj_weight = torch.empty(
1034
+ shared_expert_hidden_size,
1035
+ hidden_size,
1036
+ device=device,
1037
+ dtype=dtype,
1038
+ )
1039
+ down_proj_weight = torch.empty(
1040
+ hidden_size,
1041
+ shared_expert_hidden_size,
1042
+ device=device,
1043
+ dtype=dtype,
1044
+ )
1045
+
1046
+ # Initialize weights
1047
+ init_method(up_proj_weight)
1048
+ output_layer_init_method(down_proj_weight)
1049
+
1050
+ # No bias by default
1051
+ return up_proj_weight, down_proj_weight, None, None
1052
+
1053
+
1054
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
1055
+ # This exists because device_mesh is trapped in hook closures with no model attribute
1056
+ # Fragile - breaks if hook structure changes or Python internals change
1057
+ # TODO: Replace with a more robust solution when available
1058
+ def get_device_mesh(model):
1059
+ # Extract device_mesh from child's unused pre_hook closure
1060
+ try:
1061
+ # Find the pre-hook that contains 'device_mesh' in its closure
1062
+ hook = next(
1063
+ h
1064
+ for h in model.experts._forward_pre_hooks.values()
1065
+ if "device_mesh" in h.__code__.co_freevars
1066
+ )
1067
+ # Extract the device_mesh from the closure
1068
+ return hook.__closure__[
1069
+ hook.__code__.co_freevars.index("device_mesh")
1070
+ ].cell_contents
1071
+ except Exception:
1072
+ return None
1073
+
1074
+
1075
+ class MegaBlocksMoeMLP(torch.nn.Module):
1076
+ can_torch_compile: bool = True
1077
+
1078
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
1079
+ moe_top_k = getattr(self.router, "top_k", 4)
1080
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
1081
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
1082
+ alpha = getattr(self.experts, "alpha", 1.0)
1083
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1084
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1085
+ moe_normalize_expert_weights = getattr(
1086
+ self.experts, "normalize_expert_weights", None
1087
+ )
1088
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1089
+
1090
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
1091
+ if expert_parallel_group is None:
1092
+ device_mesh = get_device_mesh(self)
1093
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
1094
+
1095
+ has_parallel = (
1096
+ expert_parallel_group is not None
1097
+ and dist.is_initialized()
1098
+ and dist.get_world_size(expert_parallel_group) > 1
1099
+ )
1100
+ forward_fn = parallel_forward_once if has_parallel else forward_once
1101
+
1102
+ sort_end_bit = max(
1103
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1104
+ )
1105
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
1106
+ output, expert_weights_out, *_ = moe_forward(
1107
+ x=x,
1108
+ router_weight=self.router.weight,
1109
+ router_bias=self.router.bias,
1110
+ moe_top_k=moe_top_k,
1111
+ moe_num_experts=moe_num_experts,
1112
+ moe_jitter_eps=moe_jitter_eps,
1113
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
1114
+ uniform_expert_assignment=uniform_expert_assignment,
1115
+ training=self.training,
1116
+ w1=self.experts.gate_up_proj,
1117
+ w2=self.experts.down_proj,
1118
+ w1_bias=self.experts.gate_up_proj_bias,
1119
+ w2_bias=self.experts.down_proj_bias,
1120
+ gradient_scale=gradient_scale,
1121
+ alpha=alpha,
1122
+ sort_end_bit=sort_end_bit,
1123
+ expert_parallel_group=expert_parallel_group,
1124
+ moe_capacity_factor=moe_capacity_factor,
1125
+ moe_expert_model_parallelism=has_parallel,
1126
+ forward_fn=forward_fn,
1127
+ hidden_size=self.experts.hidden_size,
1128
+ mlp_impl=mlp_impl,
1129
+ )
1130
+ return output, expert_weights_out
1131
+
1132
+
1133
+ # Export main classes
1134
+ __all__ = ["MegaBlocksMoeMLP", "MegaBlocksMoeMLPWithSharedExpert"]
1135
+
1136
+
1137
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
1138
+
1139
+ def __init__(self):
1140
+ super().__init__()
1141
+ # Shared expert weights will be set by the user
1142
+ self.shared_up_proj_weight = None
1143
+ self.shared_down_proj_weight = None
1144
+ self.shared_up_proj_bias = None
1145
+ self.shared_down_proj_bias = None
1146
+ self.shared_expert_weighted_sum = False
1147
+ self.shared_activation_fn = None
1148
+
1149
+ def set_shared_expert_weights(
1150
+ self,
1151
+ up_proj_weight: torch.Tensor,
1152
+ down_proj_weight: torch.Tensor,
1153
+ up_proj_bias: Optional[torch.Tensor] = None,
1154
+ down_proj_bias: Optional[torch.Tensor] = None,
1155
+ weighted_sum: bool = False,
1156
+ activation_fn: Optional[Any] = None,
1157
+ ):
1158
+ self.shared_up_proj_weight = up_proj_weight
1159
+ self.shared_down_proj_weight = down_proj_weight
1160
+ self.shared_up_proj_bias = up_proj_bias
1161
+ self.shared_down_proj_bias = down_proj_bias
1162
+ self.shared_expert_weighted_sum = weighted_sum
1163
+ self.shared_activation_fn = activation_fn
1164
+
1165
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
1166
+ moe_top_k = getattr(self.router, "top_k", 4)
1167
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
1168
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
1169
+ alpha = getattr(self.experts, "alpha", 1.0)
1170
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
1171
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
1172
+ moe_normalize_expert_weights = getattr(
1173
+ self.experts, "normalize_expert_weights", None
1174
+ )
1175
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
1176
+
1177
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
1178
+ if expert_parallel_group is None:
1179
+ device_mesh = get_device_mesh(self)
1180
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
1181
+
1182
+ has_parallel = (
1183
+ expert_parallel_group is not None
1184
+ and dist.is_initialized()
1185
+ and dist.get_world_size(expert_parallel_group) > 1
1186
+ )
1187
+ forward_fn = parallel_forward_once if has_parallel else forward_once
1188
+
1189
+ sort_end_bit = max(
1190
+ int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1
1191
+ )
1192
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
1193
+
1194
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
1195
+ x=x,
1196
+ router_weight=self.router.weight,
1197
+ router_bias=self.router.bias,
1198
+ moe_top_k=moe_top_k,
1199
+ moe_num_experts=moe_num_experts,
1200
+ moe_jitter_eps=moe_jitter_eps,
1201
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
1202
+ uniform_expert_assignment=uniform_expert_assignment,
1203
+ training=self.training,
1204
+ w1=self.experts.gate_up_proj,
1205
+ w2=self.experts.down_proj,
1206
+ w1_bias=self.experts.gate_up_proj_bias,
1207
+ w2_bias=self.experts.down_proj_bias,
1208
+ gradient_scale=gradient_scale,
1209
+ alpha=alpha,
1210
+ sort_end_bit=sort_end_bit,
1211
+ expert_parallel_group=expert_parallel_group,
1212
+ moe_capacity_factor=moe_capacity_factor,
1213
+ moe_expert_model_parallelism=has_parallel,
1214
+ forward_fn=forward_fn,
1215
+ hidden_size=self.experts.hidden_size,
1216
+ mlp_impl=mlp_impl,
1217
+ # Shared expert parameters
1218
+ shared_up_proj_weight=self.shared_up_proj_weight,
1219
+ shared_down_proj_weight=self.shared_down_proj_weight,
1220
+ shared_up_proj_bias=self.shared_up_proj_bias,
1221
+ shared_down_proj_bias=self.shared_down_proj_bias,
1222
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
1223
+ shared_activation_fn=self.shared_activation_fn,
1224
+ )
1225
+ return output, expert_weights_out
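To make the attribute contract of MegaBlocksMoeMLP.forward concrete, here is a single-GPU wiring sketch (illustrative only, assuming the bundled ROCm extension and Triton kernels import cleanly; all sizes are arbitrary). Note that gate_up_proj interleaves gate and up columns, so it carries 2 * ffn output features per expert.

import torch

hidden, ffn, num_experts, top_k = 128, 256, 8, 2

mlp = MegaBlocksMoeMLP()
mlp.router = torch.nn.Linear(hidden, num_experts, device="cuda")
mlp.router.top_k = top_k

experts = torch.nn.Module()
experts.num_experts = num_experts
experts.hidden_size = hidden
experts.gate_up_proj = torch.nn.Parameter(0.02 * torch.randn(num_experts, hidden, 2 * ffn, device="cuda"))
experts.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(num_experts, 2 * ffn, device="cuda"))
experts.down_proj = torch.nn.Parameter(0.02 * torch.randn(num_experts, ffn, hidden, device="cuda"))
experts.down_proj_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden, device="cuda"))
mlp.experts = experts

x = torch.randn(1, 4, hidden, device="cuda")  # [sequence, batch, hidden]
out, expert_weights = mlp(x)                  # out: [1, 4, hidden]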
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .binned_gather import binned_gather
5
+ from .binned_scatter import binned_scatter
6
+ from .cumsum import exclusive_cumsum, inclusive_cumsum
7
+ from .gather import gather
8
+ from .histogram import histogram
9
+ from .padded_gather import padded_gather
10
+ from .padded_scatter import padded_scatter
11
+ from .repeat import repeat
12
+ from .replicate import replicate
13
+ from .round_up import round_up
14
+ from .scatter import scatter
15
+ from .sort import sort
16
+ from .sum import sum
17
+ from .topology import topology
18
+
19
+ __all__ = [
20
+ 'binned_gather',
21
+ 'binned_scatter',
22
+ 'exclusive_cumsum',
23
+ 'inclusive_cumsum',
24
+ 'gather',
25
+ 'histogram',
26
+ 'padded_gather',
27
+ 'padded_scatter',
28
+ 'repeat',
29
+ 'replicate',
30
+ 'round_up',
31
+ 'scatter',
32
+ 'sort',
33
+ 'sum',
34
+ 'topology',
35
+ ]
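A hedged sketch of how a few of these exported ops compose into routing metadata, mirroring indices_and_bins in layers.py. It assumes the compiled _megablocks_rocm extension loads, a ROCm/CUDA device is present, and that `megablocks.ops` is the import path for this build artifact.

import torch
from megablocks import ops  # assumed import path for this build

num_experts = 4
top_expert = torch.randint(0, num_experts, (16,), device="cuda", dtype=torch.int32)

bin_ids, indices = ops.sort(top_expert, 2)                  # 2 sort bits cover experts 0..3
tokens_per_expert = ops.histogram(top_expert, num_experts)  # tokens routed to each expert
bins = ops.inclusive_cumsum(tokens_per_expert, 0)           # expert bin boundaries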
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ # from megablocks import benchmark_util
8
+ # from megablocks.layers.all_to_all import all_to_all
9
+
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
12
+
13
+ _ALL_TO_ALL_BENCHMARK = (
14
+ (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
+ )
33
+
34
+
35
+ def benchmark_all_to_all(group, sl, hs):
36
+ world_size = dist.get_world_size(group)
37
+ assert (sl % world_size) == 0
38
+ send_recv_sizes = [sl // world_size] * world_size
39
+
40
+ x = torch.randn((sl, hs)).cuda().half()
41
+
42
+ details = {
43
+ 'world_size': world_size,
44
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
45
+ }
46
+
47
+ def benchmark():
48
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
+
50
+ time, std = benchmark_util.benchmark_function(benchmark)
51
+
52
+ if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
54
+
55
+
56
+ if __name__ == '__main__':
57
+ assert dist.is_available()
58
+ group = dist.init_process_group(backend='nccl')
59
+ local_rank = dist.get_rank(group)
60
+ torch.cuda.set_device(local_rank)
61
+
62
+ for args in _ALL_TO_ALL_BENCHMARK:
63
+ benchmark_all_to_all(group, *args)
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.sh ADDED
@@ -0,0 +1,12 @@
1
+ #!/bin/bash
2
+
3
+ DISTRIBUTED_ARGUMENTS="\
4
+ --nproc_per_node 8 \
5
+ --nnodes 1 \
6
+ --node_rank 0 \
7
+ --master_addr localhost \
8
+ --master_port 6000"
9
+
10
+ python -m torch.distributed.launch \
11
+ ${DISTRIBUTED_ARGUMENTS} \
12
+ megablocks/ops/all_to_all_benchmark.py
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_gather.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for binned_gather kernel.
12
+ class BinnedGatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bins: torch.Tensor,
21
+ bin_size: int,
22
+ top_k: int,
23
+ ):
24
+ ctx.save_for_backward(indices, bins)
25
+ ctx.top_k = top_k
26
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
27
+
28
+ @staticmethod
29
+ @custom_bwd
30
+ def backward(ctx: Any, grad: torch.Tensor):
31
+ grad = grad.contiguous()
32
+ indices, bins = ctx.saved_tensors
33
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
34
+ return out, None, None, None, None
35
+
36
+
37
+ binned_gather = BinnedGatherOp.apply
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_scatter.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for binned_scatter kernel.
12
+ class BinnedScatterOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ weights: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ top_k: int,
23
+ ):
24
+ assert len(x.size()) == 3
25
+ ctx.bin_size = x.size(1)
26
+ ctx.top_k = top_k
27
+
28
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
29
+ # calculate the gradient w.r.t. 'weights'.
30
+ ctx.save_for_backward(x, indices, weights, bins)
31
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
32
+
33
+ @staticmethod
34
+ @custom_bwd
35
+ def backward(ctx: Any, grad: torch.Tensor):
36
+ grad = grad.contiguous()
37
+ x, indices, weights, bins = ctx.saved_tensors
38
+ out = kernels.binned_gather(
39
+ grad,
40
+ indices,
41
+ weights,
42
+ bins,
43
+ ctx.bin_size,
44
+ ctx.top_k,
45
+ )
46
+
47
+ wgrad = None
48
+ if ctx.needs_input_grad[2]:
49
+ wgrad = kernels.binned_scatter_wgrad(
50
+ x,
51
+ grad,
52
+ indices,
53
+ bins,
54
+ ctx.top_k,
55
+ )
56
+ return out, None, wgrad, None, None
57
+
58
+
59
+ binned_scatter = BinnedScatterOp.apply
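Illustrative sketch (not in the diff) of how the binned_gather/binned_scatter pair is typically driven, mirroring the permute benchmark later in this build: sort tokens by expert, histogram them, build bins with a cumsum, permute to an expert-major layout, then permute back. It assumes a ROCm/CUDA device, the compiled extension, top_k == 1, and that `from megablocks import ops` resolves to this package.

    import torch
    from megablocks import ops  # assumed import path

    sl, hs, ne, top_k = 1024, 512, 8, 1
    ec = sl // ne  # expert capacity with capacity factor == 1
    x = torch.randn((sl, hs), device='cuda', dtype=torch.half)
    top_expert = torch.randint(0, ne, (sl,), device='cuda', dtype=torch.int32)

    bin_ids, indices = ops.sort(top_expert)            # sorted expert ids + permutation
    tokens_per_expert = ops.histogram(top_expert, ne)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # (ne, ec, hs) expert-major layout; tokens beyond each expert's capacity are dropped.
    x_expert = ops.binned_gather(x, indices, bins, ec, top_k)
    # Back to (sl, hs); weights=None means an unweighted scatter.
    x_out = ops.binned_scatter(x_expert, indices, None, bins, top_k)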
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/cumsum.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ # import megablocks_ops as ops # type: ignore
14
+ from .._ops import ops # type: ignore
15
+ except ModuleNotFoundError as e:
16
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
+
18
+
19
+ # Autograd wrappers for cumsum kernels.
20
+ # NOTE: Does not support gradients.
21
+ class ExclusiveCumsumOp(torch.autograd.Function):
22
+
23
+ @staticmethod
24
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
25
+ if len(x.size()) == 1:
26
+ x = x.view([1, -1])
27
+ out = torch.empty_like(x)
28
+ ops.exclusive_cumsum(x, 1, out)
29
+ return out.squeeze()
30
+ out = torch.empty_like(x)
31
+ ops.exclusive_cumsum(x, dim, out)
32
+ return out
33
+
34
+
35
+ exclusive_cumsum = ExclusiveCumsumOp.apply
36
+
37
+
38
+ class InclusiveCumsumOp(torch.autograd.Function):
39
+
40
+ @staticmethod
41
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
42
+ if len(x.size()) == 1:
43
+ x = x.view([1, -1])
44
+ out = torch.empty_like(x)
45
+ ops.inclusive_cumsum(x, 1, out)
46
+ return out.squeeze()
47
+ out = torch.empty_like(x)
48
+ ops.inclusive_cumsum(x, dim, out)
49
+ return out
50
+
51
+
52
+ inclusive_cumsum = InclusiveCumsumOp.apply
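For reference (not part of the committed file), the expected semantics of the two wrappers on a small bin-count vector, illustrated with plain torch as a stand-in; the custom ops themselves require the compiled extension and a GPU tensor.

    import torch

    tokens_per_expert = torch.tensor([3, 0, 5, 2], dtype=torch.int32)

    inclusive = torch.cumsum(tokens_per_expert, dim=0)   # [3, 3, 8, 10]
    exclusive = inclusive - tokens_per_expert            # [0, 3, 3, 8]
    # ops.inclusive_cumsum(t.cuda(), 0) / ops.exclusive_cumsum(t.cuda(), 0)
    # are expected to match these values.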
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/gather.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for gather kernel.
12
+ class GatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ top_k: int,
23
+ ):
24
+ ctx.save_for_backward(indices, bin_ids, bins)
25
+ ctx.top_k = top_k
26
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
27
+
28
+ @staticmethod
29
+ @custom_bwd
30
+ def backward(ctx: Any, grad: torch.Tensor):
31
+ grad = grad.contiguous()
32
+
33
+ indices, bin_ids, bins = ctx.saved_tensors
34
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
35
+ return out, None, None, None, None, None
36
+
37
+
38
+ gather = GatherOp.apply
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from .._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for histogram kernel.
19
+ # NOTE: Does not support gradients.
20
+ class HistogramOp(torch.autograd.Function):
21
+
22
+ @staticmethod
23
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
24
+ return ops.histogram(x, max_val)
25
+
26
+
27
+ histogram = HistogramOp.apply
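As used throughout this package, ops.histogram(x, max_val) counts occurrences of each value in [0, max_val). A plain-torch stand-in for the same result (not part of the committed file):

    import torch

    ne = 4
    top_expert = torch.tensor([2, 0, 2, 3, 0, 2])

    reference = torch.bincount(top_expert, minlength=ne)   # tensor([2, 0, 3, 1])
    # ops.histogram(top_expert.cuda().int(), ne) is expected to match this.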
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram_benchmark.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import numpy as np
7
+ import torch
8
+ from absl.testing import parameterized
9
+
10
+ from .. import ops
11
+
12
+ _HISTOGRAM_TESTS = (
13
+ (16384, torch.int32, 2),
14
+ (16384, torch.int32, 4),
15
+ (16384, torch.int32, 8),
16
+ (16384, torch.int32, 16),
17
+ (16384, torch.int32, 32),
18
+ (16384, torch.int32, 64),
19
+ (16384, torch.int32, 128),
20
+ (16384, torch.int32, 256),
21
+ )
22
+
23
+
24
+ def benchmark_function(fn, iterations=10):
25
+ # Run once to get rid of startup overhead.
26
+ fn()
27
+ times = []
28
+ for _ in range(iterations):
29
+ start = torch.cuda.Event(enable_timing=True)
30
+ end = torch.cuda.Event(enable_timing=True)
31
+ start.record()
32
+ fn()
33
+ end.record()
34
+ torch.cuda.synchronize()
35
+ times.append(start.elapsed_time(end))
36
+ times = np.array(times)
37
+ return times.mean(), times.std(), times.max(), times.min()
38
+
39
+
40
+ def log_benchmark(arguments, mean_t, std_t):
41
+ print('=' * 60)
42
+ print('Benchmark Parameters:')
43
+ for (key, value) in arguments.items():
44
+ print(f'{key} = {value}')
45
+ print('Results:')
46
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
47
+ print('=' * 60)
48
+
49
+
50
+ class HistogramBenchmark(parameterized.TestCase):
51
+
52
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
53
+ def testHistogram(self, n, dtype, max_val):
54
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
+
56
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
+ arguments = {
58
+ 'n': n,
59
+ 'dtype': dtype,
60
+ 'max_val': max_val,
61
+ }
62
+ log_benchmark(arguments, mean_t, std_t)
63
+
64
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
65
+ def testTorchHistogram(self, n, dtype, max_val):
66
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
67
+
68
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
+ arguments = {
70
+ 'n': n,
71
+ 'dtype': dtype,
72
+ 'max_val': max_val,
73
+ }
74
+ log_benchmark(arguments, mean_t, std_t)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ unittest.main()
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/matmul_benchmark.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+
7
+ # import stk
8
+
9
+ # try:
10
+ # import stk
11
+ # except ImportError:
12
+ # import warnings
13
+ # warnings.warn(
14
+ # 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
15
+ # )
16
+
17
+ from .. import stk
18
+
19
+ import torch
20
+ from absl.testing import parameterized
21
+
22
+ from .. import benchmark_util, ops
23
+
24
+
25
+ # Calling tensor.t() calls tensor.transpose(0, 1) which calls
26
+ # torch.as_strided(...). Circumvent this chain to avoid an overhead
27
+ # this adds.
28
+ def transpose_view(x):
29
+ return torch.as_strided(
30
+ x,
31
+ (x.shape[1], x.shape[0]),
32
+ (x.stride()[1], x.stride()[0]),
33
+ )
34
+
35
+
36
+ _MATMUL_TESTS = (
37
+ (64 * 1024, 512, 2048, 64),
38
+ (32 * 1024, 768, 3072, 64),
39
+ (8 * 1024, 1024, 4096, 64),
40
+ (4 * 2048, 4096, 4 * 4096, 4),
41
+ )
42
+
43
+
44
+ def log_benchmark(name, arguments, time, std, flops):
45
+ benchmark_util.log_benchmark(name, arguments, time, std)
46
+ print('flops = {:.2f}B'.format(flops / 1e9))
47
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
48
+ print('=' * 60)
49
+
50
+
51
+ class MatmulBenchmark(parameterized.TestCase):
52
+
53
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
+ blocking = 128
55
+ padded_tokens, _ = x.size()
56
+ assert padded_tokens % blocking == 0
57
+ assert fhs % blocking == 0
58
+
59
+ # Offsets for the sparse matrix. All rows have the
60
+ # same number of nonzero blocks dictated by the
61
+ # dimensionality of a single expert.
62
+ block_rows = padded_tokens // blocking
63
+ blocks_per_row = fhs // blocking
64
+ offsets = torch.arange(
65
+ 0,
66
+ block_rows * blocks_per_row + 1,
67
+ blocks_per_row,
68
+ dtype=torch.int32,
69
+ device=x.device,
70
+ )
71
+
72
+ # Indices for the sparse matrix. The indices for
73
+ # the intermediate matrix are dynamic depending
74
+ # on the mapping of tokens to experts.
75
+ column_indices = ops.topology(
76
+ padded_bins,
77
+ blocking,
78
+ block_rows,
79
+ blocks_per_row,
80
+ )
81
+ data = torch.empty(
82
+ column_indices.numel(),
83
+ blocking,
84
+ blocking,
85
+ dtype=torch.float16,
86
+ device=x.device,
87
+ )
88
+ shape = (padded_tokens, fhs * ne)
89
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
+
92
+ def build_input_matrix(self, sl, hs, ne):
93
+ x = torch.randn((sl, hs)).cuda().half()
94
+
95
+ # Assign tokens to experts uniformly.
96
+ top_expert = torch.arange(0, sl).cuda().int() % ne
97
+
98
+ bin_ids, indices = ops.sort(top_expert)
99
+ tokens_per_expert = ops.histogram(top_expert, ne)
100
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
+ return out, padded_bins
105
+
106
+ def build_weight_matrix(self, ne, hs, fhs):
107
+ return torch.randn((hs, ne * fhs)).cuda().half()
108
+
109
+ @parameterized.parameters(*_MATMUL_TESTS)
110
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
+ w = transpose_view(w)
115
+
116
+ def benchmark():
117
+ return stk.ops.sdd(x, w, topo)
118
+
119
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ arguments = {
121
+ 'sequence_length': sl,
122
+ 'hidden_size': hs,
123
+ 'ffn_hidden_size': fhs,
124
+ 'num_experts': ne,
125
+ }
126
+ log_benchmark(
127
+ '0::Fwd::SDD::NT',
128
+ arguments,
129
+ mean_t,
130
+ std_t,
131
+ x.numel() * fhs * 2,
132
+ )
133
+
134
+ @parameterized.parameters(*_MATMUL_TESTS)
135
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
+
140
+ def benchmark():
141
+ return stk.ops.dsd(topo, w)
142
+
143
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
+ arguments = {
145
+ 'sequence_length': sl,
146
+ 'hidden_size': hs,
147
+ 'ffn_hidden_size': fhs,
148
+ 'num_experts': ne,
149
+ }
150
+ log_benchmark(
151
+ '0::GradX::DSD::NN',
152
+ arguments,
153
+ mean_t,
154
+ std_t,
155
+ x.numel() * fhs * 2,
156
+ )
157
+
158
+ @parameterized.parameters(*_MATMUL_TESTS)
159
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
+ topo = topo.t()
163
+
164
+ def benchmark():
165
+ return stk.ops.dsd(topo, x)
166
+
167
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
+ arguments = {
169
+ 'sequence_length': sl,
170
+ 'hidden_size': hs,
171
+ 'ffn_hidden_size': fhs,
172
+ 'num_experts': ne,
173
+ }
174
+ log_benchmark(
175
+ '0::GradW::DSD::TN',
176
+ arguments,
177
+ mean_t,
178
+ std_t,
179
+ x.numel() * fhs * 2,
180
+ )
181
+
182
+ @parameterized.parameters(*_MATMUL_TESTS)
183
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
+
188
+ def benchmark():
189
+ return stk.ops.dsd(x, w)
190
+
191
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
+ arguments = {
193
+ 'sequence_length': sl,
194
+ 'hidden_size': hs,
195
+ 'ffn_hidden_size': fhs,
196
+ 'num_experts': ne,
197
+ }
198
+ log_benchmark(
199
+ '1::Fwd::DSD::NN',
200
+ arguments,
201
+ mean_t,
202
+ std_t,
203
+ x.nnz * hs * 2,
204
+ )
205
+
206
+ @parameterized.parameters(*_MATMUL_TESTS)
207
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
+ out = stk.ops.dsd(x, w)
212
+ w = transpose_view(w)
213
+
214
+ def benchmark():
215
+ return stk.ops.sdd(out, w, x)
216
+
217
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
+ arguments = {
219
+ 'sequence_length': sl,
220
+ 'hidden_size': hs,
221
+ 'ffn_hidden_size': fhs,
222
+ 'num_experts': ne,
223
+ }
224
+ log_benchmark(
225
+ '1::GradX::SDD::NT',
226
+ arguments,
227
+ mean_t,
228
+ std_t,
229
+ x.nnz * hs * 2,
230
+ )
231
+
232
+ @parameterized.parameters(*_MATMUL_TESTS)
233
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
+ out = stk.ops.dsd(x, w)
238
+ x = x.t()
239
+
240
+ def benchmark():
241
+ return stk.ops.dsd(x, out)
242
+
243
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
+ arguments = {
245
+ 'sequence_length': sl,
246
+ 'hidden_size': hs,
247
+ 'ffn_hidden_size': fhs,
248
+ 'num_experts': ne,
249
+ }
250
+ log_benchmark(
251
+ '1::GradW::DSD::TN',
252
+ arguments,
253
+ mean_t,
254
+ std_t,
255
+ x.nnz * hs * 2,
256
+ )
257
+
258
+ @parameterized.parameters(*_MATMUL_TESTS)
259
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
+ assert (sl % ne) == 0
261
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
+ w = torch.randn((ne, hs, fhs)).cuda().half()
263
+
264
+ w = w.transpose(1, 2).contiguous()
265
+ w = w.transpose(1, 2)
266
+
267
+ def benchmark():
268
+ return torch.bmm(x, w)
269
+
270
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
+ arguments = {
272
+ 'sequence_length': sl,
273
+ 'hidden_size': hs,
274
+ 'ffn_hidden_size': fhs,
275
+ 'num_experts': ne,
276
+ }
277
+ log_benchmark(
278
+ '0::Fwd::DDD::NT',
279
+ arguments,
280
+ mean_t,
281
+ std_t,
282
+ x.numel() * fhs * 2,
283
+ )
284
+
285
+ @parameterized.parameters(*_MATMUL_TESTS)
286
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
+ assert (sl % ne) == 0
288
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
+ w = torch.randn((ne, hs, fhs)).cuda().half()
290
+ out = torch.bmm(x, w)
291
+ w = w.transpose(1, 2).contiguous()
292
+
293
+ def benchmark():
294
+ return torch.bmm(out, w)
295
+
296
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
+ arguments = {
298
+ 'sequence_length': sl,
299
+ 'hidden_size': hs,
300
+ 'ffn_hidden_size': fhs,
301
+ 'num_experts': ne,
302
+ }
303
+ log_benchmark(
304
+ '0::GradX::DDD::NN',
305
+ arguments,
306
+ mean_t,
307
+ std_t,
308
+ x.numel() * fhs * 2,
309
+ )
310
+
311
+ @parameterized.parameters(*_MATMUL_TESTS)
312
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
+ assert (sl % ne) == 0
314
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
+ w = torch.randn((ne, hs, fhs)).cuda().half()
316
+ out = torch.bmm(x, w)
317
+ out = out.transpose(1, 2)
318
+
319
+ def benchmark():
320
+ return torch.bmm(out, x)
321
+
322
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
+ arguments = {
324
+ 'sequence_length': sl,
325
+ 'hidden_size': hs,
326
+ 'ffn_hidden_size': fhs,
327
+ 'num_experts': ne,
328
+ }
329
+ log_benchmark(
330
+ '0::GradW::DDD::TN',
331
+ arguments,
332
+ mean_t,
333
+ std_t,
334
+ x.numel() * fhs * 2,
335
+ )
336
+
337
+ @parameterized.parameters(*_MATMUL_TESTS)
338
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
+ assert (sl % ne) == 0
340
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
+ w = torch.randn((ne, fhs, hs)).cuda().half()
342
+
343
+ def benchmark():
344
+ return torch.bmm(x, w)
345
+
346
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
+ arguments = {
348
+ 'sequence_length': sl,
349
+ 'hidden_size': hs,
350
+ 'ffn_hidden_size': fhs,
351
+ 'num_experts': ne,
352
+ }
353
+ log_benchmark(
354
+ '1::Fwd::DDD::NN',
355
+ arguments,
356
+ mean_t,
357
+ std_t,
358
+ x.numel() * hs * 2,
359
+ )
360
+
361
+ @parameterized.parameters(*_MATMUL_TESTS)
362
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
+ assert (sl % ne) == 0
364
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
+ w = torch.randn((ne, fhs, hs)).cuda().half()
366
+ out = torch.bmm(x, w)
367
+ w = torch.transpose(w, 1, 2)
368
+
369
+ def benchmark():
370
+ return torch.bmm(out, w)
371
+
372
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
+ arguments = {
374
+ 'sequence_length': sl,
375
+ 'hidden_size': hs,
376
+ 'ffn_hidden_size': fhs,
377
+ 'num_experts': ne,
378
+ }
379
+ log_benchmark(
380
+ '1::GradX::DDD::NT',
381
+ arguments,
382
+ mean_t,
383
+ std_t,
384
+ x.numel() * hs * 2,
385
+ )
386
+
387
+ @parameterized.parameters(*_MATMUL_TESTS)
388
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
+ assert (sl % ne) == 0
390
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
+ w = torch.randn((ne, fhs, hs)).cuda().half()
392
+ out = torch.bmm(x, w)
393
+ x = torch.transpose(x, 1, 2)
394
+
395
+ def benchmark():
396
+ return torch.bmm(x, out)
397
+
398
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
+ arguments = {
400
+ 'sequence_length': sl,
401
+ 'hidden_size': hs,
402
+ 'ffn_hidden_size': fhs,
403
+ 'num_experts': ne,
404
+ }
405
+ log_benchmark(
406
+ '1::GradW::DDD::TN',
407
+ arguments,
408
+ mean_t,
409
+ std_t,
410
+ x.numel() * hs * 2,
411
+ )
412
+
413
+
414
+ if __name__ == '__main__':
415
+ unittest.main()
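A self-contained check (not in the diff) of the `transpose_view` trick used near the top of this file: swapping sizes and strides via torch.as_strided yields the same view as .t() while skipping the transpose call chain.

    import torch

    x = torch.randn(4, 8)
    xt = torch.as_strided(x, (x.shape[1], x.shape[0]), (x.stride()[1], x.stride()[0]))
    assert torch.equal(xt, x.t())           # same values
    assert xt.data_ptr() == x.data_ptr()    # same storage, no copy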
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_gather.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for padded_gather kernel.
12
+ class PaddedGatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ padded_bins: torch.Tensor,
23
+ top_k: int,
24
+ ):
25
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
26
+ ctx.top_k = top_k
27
+ return kernels.padded_gather(
28
+ x,
29
+ indices,
30
+ bin_ids,
31
+ None,
32
+ bins,
33
+ padded_bins,
34
+ top_k,
35
+ )
36
+
37
+ @staticmethod
38
+ @custom_bwd
39
+ def backward(ctx: Any, grad: torch.Tensor):
40
+ grad = grad.contiguous()
41
+
42
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
43
+ out = kernels.padded_scatter(
44
+ grad,
45
+ indices,
46
+ bin_ids,
47
+ None,
48
+ bins,
49
+ padded_bins,
50
+ ctx.top_k,
51
+ )
52
+ return out, None, None, None, None, None
53
+
54
+
55
+ padded_gather = PaddedGatherOp.apply
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for padded_scatter kernel.
12
+ class PaddedScatterOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ weights: torch.Tensor,
22
+ bins: torch.Tensor,
23
+ padded_bins: torch.Tensor,
24
+ top_k: int,
25
+ ):
26
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
27
+ ctx.save_for_backward(
28
+ indices,
29
+ bin_ids,
30
+ weights,
31
+ bins,
32
+ padded_bins,
33
+ *maybe_x,
34
+ )
35
+ ctx.top_k = top_k
36
+ ctx.x_shape = x.shape
37
+ return kernels.padded_scatter(
38
+ x,
39
+ indices,
40
+ bin_ids,
41
+ weights,
42
+ bins,
43
+ padded_bins,
44
+ top_k,
45
+ )
46
+
47
+ @staticmethod
48
+ @custom_bwd
49
+ def backward(ctx: Any, grad: torch.Tensor):
50
+ grad = grad.contiguous()
51
+ saved_tensors = ctx.saved_tensors
52
+
53
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
54
+ dgrad = None
55
+ if ctx.needs_input_grad[0]:
56
+ dgrad = kernels.padded_gather(
57
+ grad,
58
+ indices,
59
+ bin_ids,
60
+ weights,
61
+ bins,
62
+ padded_bins,
63
+ ctx.top_k,
64
+ )
65
+
66
+ wgrad = None
67
+ if ctx.needs_input_grad[3]: # need wgrad
68
+ x = saved_tensors[-1]
69
+ wgrad = kernels.padded_scatter_wgrad(
70
+ x,
71
+ grad,
72
+ indices,
73
+ bin_ids,
74
+ bins,
75
+ padded_bins,
76
+ ctx.top_k,
77
+ )
78
+ return dgrad, None, None, wgrad, None, None, None, None
79
+
80
+
81
+ def padded_scatter(
82
+ x: torch.Tensor,
83
+ indices: torch.Tensor,
84
+ bin_ids: torch.Tensor,
85
+ weights: torch.Tensor,
86
+ bins: torch.Tensor,
87
+ padded_bins: torch.Tensor,
88
+ top_k: int,
89
+ ):
90
+ return PaddedScatterOp.apply(
91
+ x,
92
+ indices,
93
+ bin_ids,
94
+ weights,
95
+ bins,
96
+ padded_bins,
97
+ top_k,
98
+ )
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import torch
7
+ from absl.testing import parameterized
8
+
9
+ from .. import benchmark_util, ops
10
+
11
+ _PADDED_SCATTER_BENCHMARK = (
12
+ # dMoE-Medium, 8-way EMP.
13
+ (1024 * 16, 1024, 8, 4),
14
+ # dMoE-Medium, post-all-to-all.
15
+ (1024 * 16 * 4, 1024, 8, 1),
16
+ )
17
+
18
+
19
+ class PaddedScatterTest(parameterized.TestCase):
20
+
21
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
+ def testPaddedScatter(self, sl, hs, ne, top_k):
23
+ # Create the data and indices.
24
+ x = torch.randn((sl, hs)).cuda().half()
25
+
26
+ # Randomly assign tokens to experts.
27
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
+ bin_ids, indices = ops.sort(top_expert)
29
+ tokens_per_expert = ops.histogram(top_expert, ne)
30
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
+
34
+ # Sample weights for the scatter reduce.
35
+ weights = torch.rand((sl * top_k,)).cuda().half()
36
+
37
+ # Gather the data to prepare for backwards.
38
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
+
40
+ def benchmark():
41
+ return ops.padded_scatter(
42
+ x,
43
+ indices,
44
+ bin_ids,
45
+ weights,
46
+ bins,
47
+ padded_bins,
48
+ top_k,
49
+ )
50
+
51
+ time, std = benchmark_util.benchmark_function(benchmark)
52
+ benchmark_util.log_benchmark(
53
+ 'Padded Scatter',
54
+ {
55
+ 'sequence_length': sl,
56
+ 'hidden_size': hs,
57
+ 'num_experts': ne,
58
+ 'top_k': top_k,
59
+ },
60
+ time,
61
+ std,
62
+ )
63
+
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/permute_benchmark.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import torch
7
+ from absl.testing import parameterized
8
+
9
+ from .. import benchmark_util, ops
10
+
11
+ _PERMUTE_TESTS = (
12
+ (16384, 768, 2),
13
+ (16384, 768, 4),
14
+ (16384, 768, 8),
15
+ (16384, 768, 16),
16
+ (16384, 768, 32),
17
+ (16384, 768, 64),
18
+ (16384, 768, 128),
19
+ (16384 * 8, 768, 2),
20
+ (16384 * 8, 768, 4),
21
+ (16384 * 8, 768, 8),
22
+ (16384 * 8, 768, 16),
23
+ (16384 * 8, 768, 32),
24
+ (16384 * 8, 768, 64),
25
+ (16384 * 8, 768, 128),
26
+ )
27
+
28
+
29
+ class PermuteBenchmark(parameterized.TestCase):
30
+
31
+ @parameterized.parameters(*_PERMUTE_TESTS)
32
+ def testBinnedGather(self, sl, hs, ne):
33
+ # NOTE: Capacity factor == 1.
34
+ ec = sl // ne
35
+
36
+ # Create the data and indices.
37
+ x = torch.randn((sl, hs)).cuda().half()
38
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
+ bin_ids, indices = ops.sort(top_expert)
40
+ tokens_per_expert = ops.histogram(indices, ne)
41
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
+
43
+ def benchmark():
44
+ return ops.binned_gather(x, indices, bins, ec)
45
+
46
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
+ arguments = {
48
+ 'sequence_length': sl,
49
+ 'hidden_size': hs,
50
+ 'num_experts': ne,
51
+ }
52
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
+
54
+ @parameterized.parameters(*_PERMUTE_TESTS)
55
+ def testBinnedScatter(self, sl, hs, ne):
56
+ # NOTE: Capacity factor == 1.
57
+ ec = sl // ne
58
+
59
+ # Create the data and indices.
60
+ x = torch.randn((sl, hs)).cuda().half()
61
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
+ bin_ids, indices = ops.sort(top_expert)
63
+ tokens_per_expert = ops.histogram(indices, ne)
64
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+ x = ops.binned_gather(x, indices, bins, ec)
66
+
67
+ def benchmark():
68
+ return ops.binned_scatter(x, indices, bins)
69
+
70
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
+ arguments = {
72
+ 'sequence_length': sl,
73
+ 'hidden_size': hs,
74
+ 'num_experts': ne,
75
+ }
76
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
+
78
+ @parameterized.parameters(*_PERMUTE_TESTS)
79
+ def testPaddedGather(self, sl, hs, ne):
80
+ # Create the data and indices.
81
+ x = torch.randn((sl, hs)).cuda().half()
82
+
83
+ # Randomly assign tokens to experts.
84
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
+ bin_ids, indices = ops.sort(top_expert)
86
+ tokens_per_expert = ops.histogram(top_expert, ne)
87
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+
91
+ def benchmark():
92
+ return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
+
94
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
+ arguments = {
96
+ 'sequence_length': sl,
97
+ 'hidden_size': hs,
98
+ 'num_experts': ne,
99
+ }
100
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
+
102
+ @parameterized.parameters(*_PERMUTE_TESTS)
103
+ def testPaddedScatter(self, sl, hs, ne):
104
+ # Create the data and indices.
105
+ x = torch.randn((sl, hs)).cuda().half()
106
+
107
+ # Randomly assign tokens to experts.
108
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
+ bin_ids, indices = ops.sort(top_expert)
110
+ tokens_per_expert = ops.histogram(top_expert, ne)
111
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
+
116
+ def benchmark():
117
+ return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
+
119
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ arguments = {
121
+ 'sequence_length': sl,
122
+ 'hidden_size': hs,
123
+ 'num_experts': ne,
124
+ }
125
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
+
127
+ @parameterized.parameters(*_PERMUTE_TESTS)
128
+ def testCopy(self, sl, hs, ne):
129
+ # NOTE: Capacity factor == 1.
130
+ # ec = sl // ne
131
+
132
+ # Create the data and indices.
133
+ x = torch.randn((sl, hs)).cuda().half()
134
+ y = x.clone()
135
+
136
+ def benchmark():
137
+ return y.copy_(x)
138
+
139
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
+ arguments = {
141
+ 'sequence_length': sl,
142
+ 'hidden_size': hs,
143
+ 'num_experts': ne,
144
+ }
145
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
+
147
+
148
+ if __name__ == '__main__':
149
+ unittest.main()
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/repeat.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+
7
+ def repeat(x: torch.Tensor, tiling: torch.Size):
8
+ if all((t == 1 for t in tiling)):
9
+ return x
10
+ return x.repeat(*tiling)
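A quick illustration (not in the diff) of the early-return behaviour of `repeat` above, assuming the helper is importable; the exact module path may differ.

    import torch
    from megablocks.ops.repeat import repeat  # assumed import path

    x = torch.arange(6).reshape(2, 3)
    assert repeat(x, torch.Size((1, 1))) is x              # all-ones tiling: no copy
    assert repeat(x, torch.Size((2, 1))).shape == (4, 3)   # tiled along dim 0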
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/replicate.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from .._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for replicate kernel.
19
+ class ReplicateOp(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
23
+ ctx.save_for_backward(bins)
24
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
25
+ ops.replicate_forward(x, bins, out)
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx: Any, grad: torch.Tensor):
30
+ bins, = ctx.saved_tensors
31
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
32
+ ops.replicate_backward(grad, bins, out)
33
+ return out, None, None
34
+
35
+
36
+ replicate = ReplicateOp.apply
build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/round_up.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+
7
+ def round_up(x: torch.Tensor, value: int):
8
+ assert isinstance(value, int)
9
+ assert x.dtype == torch.int32
10
+
11
+ # TODO(tgale): If this becomes an issue,
12
+ # do this in a custom kernel. We only expect
13
+ # to use this on arrays of less than 1k elements.
14
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
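For reference (not part of the committed file), the trunc-division formula above rounds each int32 entry up to the next multiple of `value`:

    import torch

    x = torch.tensor([3, 128, 130], dtype=torch.int32)
    value = 128
    rounded = torch.div(x + (value - 1), value, rounding_mode='trunc') * value
    # rounded -> tensor([128, 128, 256], dtype=torch.int32)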